@@ -398,7 +398,7 @@ def test_all_to_all_vdev_2d(self, align: int) -> None:
398398 nsplits , dtype = torch .int64 , device = self .device
399399 ).copy_ (inp_splits )
400400 # 2 rows: output splits, output offsets
401- # Initiallizing all values to -1 to check if they are updated
401+ # Initializing all values to -1 to check if they are updated
402402 out_splits_offsets = symm_mem .empty (
403403 (2 , nsplits ), dtype = torch .int64 , device = self .device
404404 ).fill_ (- 1 )
@@ -503,7 +503,7 @@ def test_all_to_all_vdev_2d_offset(self) -> None:
503503 (2 , nsplits ), dtype = torch .int64 , device = self .device
504504 )
505505 # 2 rows: output splits, output offsets
506- # Initiallizing all values to -1 to check if they are updated
506+ # Initializing all values to -1 to check if they are updated
507507 out_splits_offsets = symm_mem .empty (
508508 (2 , nsplits ), dtype = torch .int64 , device = self .device
509509 ).fill_ (- 1 )
@@ -617,15 +617,15 @@ def dispatch_then_combine(device, align: int, group) -> None:
617617 inp_splits
618618 )
619619 # 2 rows: output splits, output offsets
620- # Initiallizing all values to -1 to check if they are updated
620+ # Initializing all values to -1 to check if they are updated
621621 out_splits_offsets = symm_mem .empty (
622622 (2 , nsplits ), dtype = torch .int64 , device = device
623623 ).fill_ (- 1 )
624624
625625 # Buffers for combine
626626 combine_out = symm_mem .empty (max_out_numel , dtype = dtype , device = device ).fill_ (- 1 )
627627 # 2 rows: output splits, output offsets
628- # Initiallizing all values to -1 to check if they are updated
628+ # Initializing all values to -1 to check if they are updated
629629 combine_out_splits_offsets = symm_mem .empty (
630630 (2 , nsplits ), dtype = torch .int64 , device = device
631631 ).fill_ (- 1 )
0 commit comments