@@ -150,7 +150,6 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,
-    "F8_E5M2": torch.float8_e5m2,
 }


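For context, the lines above belong to a map from safetensors dtype strings to torch dtypes. A minimal sketch of how such a map is typically consumed when casting loaded shards; the names below are assumptions for illustration, not identifiers from this file.

```python
import torch

# Hypothetical stand-in for the map edited above; safetensors headers report dtypes
# as short strings ("F64", "F8_E4M3", ...). Names are illustrative only.
_STR_TO_DTYPE = {
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,
}

def cast_from_header(tensor: torch.Tensor, dtype_str: str) -> torch.Tensor:
    # Unknown strings (e.g. "F8_E5M2" after this change) fall back to the tensor's
    # current dtype instead of raising at load time.
    return tensor.to(_STR_TO_DTYPE.get(dtype_str, tensor.dtype))
```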
@@ -526,43 +525,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        return param


-class ReduceFromModelParallelRegion(torch.autograd.Function):
-    """
-    All-reduce in forward pass, identity in backward pass.
-    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        if device_mesh.size() == 1:
-            return x
-        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output
-
-
-class CopyToModelParallelRegion(torch.autograd.Function):
-    """
-    Copy in forward pass, all-reduce in backward pass.
-    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        ctx.device_mesh = device_mesh
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.device_mesh.size() == 1:
-            return grad_output
-        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
-        return grad_output
-
-
 class ColwiseParallel(TensorParallelLayer):
     """
     General tensor parallel layer for transformers.
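The two autograd functions removed here hand-code the Megatron-LM `f`/`g` pair (identity one way, all-reduce the other). The rest of the diff expresses the same collectives through DTensor placements, where `redistribute` performs the communication and autograd derives the reverse pass. A minimal sketch of that equivalence, assuming a torchrun launch, CUDA devices, and a PyTorch recent enough to expose `torch.distributed.tensor` publicly; the mesh setup is illustrative, not taken from this file.

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# Illustrative 1-D tensor-parallel mesh over all launched ranks.
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# `g` (the removed ReduceFromModelParallelRegion): each rank holds a partial sum and the
# forward all-reduces it; backward is the identity. With DTensor the same collective is a
# Partial -> Replicate redistribute.
local_partial = torch.randn(4, 8, device="cuda")
partial_out = DTensor.from_local(local_partial, tp_mesh, [Partial()], run_check=False)
replicated = partial_out.redistribute(placements=[Replicate()])  # all-reduce in forward

# `f` (the removed CopyToModelParallelRegion) is the transpose: identity forward,
# all-reduce of gradients in backward. That is exactly the backward of the redistribute
# above, so no hand-written autograd.Function is needed once activations are DTensors.
```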
@@ -585,8 +547,15 @@ def __init__(

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        # transform the input layouts to the desired layouts of ColwiseParallel
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
        return input_tensor

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
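These static hooks follow the `input_fn(mod, inputs, device_mesh)` / `output_fn(mod, outputs, device_mesh)` contract of `torch.distributed.tensor.distribute_module`, with the layout arguments bound up front. A sketch of that wiring, assuming the import path of `ColwiseParallel` and the default layouts shown here; both are illustrative, not read from this diff.

```python
from functools import partial

import torch.nn as nn
from torch.distributed.tensor import Replicate, Shard, distribute_module

# Assumed import path for the class being edited in this diff.
from transformers.integrations.tensor_parallel import ColwiseParallel


def shard_colwise(module: nn.Module, device_mesh) -> nn.Module:
    # Bind (input_layouts, desired_input_layouts) and (output_layouts, use_local_output)
    # so the remaining parameters match distribute_module's 3-argument hook signature.
    return distribute_module(
        module,
        device_mesh,
        partition_fn=None,  # parameters assumed to be sharded elsewhere via partition_tensor
        input_fn=partial(ColwiseParallel._prepare_input_fn, (Replicate(),), (Replicate(),)),
        output_fn=partial(ColwiseParallel._prepare_output_fn, (Shard(-1),), True),
    )
```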
@@ -595,19 +564,41 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # weight would become Shard(1)
        if param_type == "bias":
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
        else:
+            shard = [Shard(-2)]
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)

        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = CopyToModelParallelRegion.apply(outputs, device_mesh)
-        return outputs
+        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
+        # back to local tensor
+        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
+
+
+class PackedColwiseParallel(ColwiseParallel):
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
+        parameter = parameter.to(param_casting_dtype)
+        if to_contiguous:
+            parameter = parameter.contiguous()
+        if self.use_dtensor:
+            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


 class RowwiseParallel(TensorParallelLayer):
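Why the colwise weight shard is `Shard(-2)` and the output lands on `Shard(-1)`: `nn.Linear` stores its weight as `[out_features, in_features]` and computes `x @ W^T`, so splitting `out_features` splits the last dimension of the output. A self-contained sketch under the same assumptions as above (torchrun launch, CUDA, recent PyTorch); sizes are illustrative.

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

weight = distribute_tensor(torch.randn(32, 16, device="cuda"), tp_mesh, [Shard(0)])  # colwise: shard out_features
x = distribute_tensor(torch.randn(4, 16, device="cuda"), tp_mesh, [Replicate()])     # replicated activations

y = torch.matmul(x, weight.t())                     # sharding propagation gives Shard(1), i.e. Shard(-1)
y_full = y.redistribute(placements=[Replicate()])   # what _prepare_output_fn does when asked for Replicate()
print(y.placements, y_full.to_local().shape)
```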
@@ -644,15 +635,23 @@ def __init__(
        self.use_dtensor = use_dtensor

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        if param_type == "bias":
-            parameter = param[:]
-        else:
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        if param_type != "bias":
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
+        else:
+            shard = [Replicate()]
+            parameter = param[:]

        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

    @staticmethod
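Why the rowwise weight shard is `Shard(-1)` while the bias stays `Replicate()`: `y = x @ W^T` contracts over `in_features`, so splitting the weight's last dimension leaves each rank with a partial sum of `y`, and the bias must be added only once, after those partials are reduced. A sketch under the same assumptions as the earlier blocks; sizes are illustrative.

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

weight = distribute_tensor(torch.randn(32, 16, device="cuda"), tp_mesh, [Shard(1)])  # rowwise: shard in_features
x = distribute_tensor(torch.randn(4, 16, device="cuda"), tp_mesh, [Shard(1)])        # activations sharded on features

y = torch.matmul(x, weight.t())  # both operands sharded on the contraction dim
print(y.placements)              # Partial(sum): each rank holds an unreduced contribution
```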
@@ -662,13 +661,24 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
            mod.bias = None

        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh)
+        # Rowwise sharding produces partial output, depending on output layouts:
+        # 1. to replicate -> allreduce
+        # 2. to shard -> reduce_scatter
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        outputs = outputs.to_local()  # otherwise the `+=` op will gather
        if hasattr(mod, "_bias"):
            outputs += mod._bias
+        # back to local tensor if use_local_output is True
        return outputs

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
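The two output paths named in the comments above, sketched directly on a Partial DTensor: redistributing to `Replicate()` is an all-reduce, redistributing to a `Shard(...)` placement is a reduce-scatter, and the stashed bias is added only after dropping to a local tensor so the `+=` does not trigger another collective. Same assumptions as the earlier blocks (torchrun, CUDA, recent PyTorch); shapes are illustrative.

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
partial_y = DTensor.from_local(torch.randn(4, 32, device="cuda"), tp_mesh, [Partial()], run_check=False)

replicated = partial_y.redistribute(placements=[Replicate()])  # 1. to replicate -> all-reduce
scattered = partial_y.redistribute(placements=[Shard(1)])      # 2. to shard -> reduce-scatter

# As in _prepare_output_fn: add the (replicated) bias to the rank-local result.
bias = torch.zeros(32, device="cuda")
out = replicated.to_local() + bias
```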
@@ -694,21 +704,6 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
            )


-class PackedColwiseParallel(ColwiseParallel):
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # NOTE(3outeille): need to be deprecated as no longer using dtensors
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
-
 class PackedRowwiseParallel(RowwiseParallel):
    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
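The `Packed*` variants go through `get_packed_weights` instead of `get_tensor_shard`. Assuming the packed parameter fuses several projections along the sharded dimension (for example a fused gate/up weight), a plain chunk would give some ranks only gate rows and others only up rows; each packed block has to be sliced separately and re-concatenated per rank. The helper below is an illustrative stand-in for that behaviour, not the actual `get_packed_weights` implementation.

```python
import torch

def shard_packed_colwise(packed: torch.Tensor, world_size: int, rank: int, n_packed: int = 2) -> torch.Tensor:
    # Shard each packed block separately, then re-concatenate so every rank keeps a
    # slice of each fused projection (e.g. a gate slice and an up slice).
    blocks = packed.chunk(n_packed, dim=0)
    return torch.cat([b.chunk(world_size, dim=0)[rank] for b in blocks], dim=0)

# Example: a fused [2 * intermediate, hidden] weight split over 4 ranks.
packed = torch.arange(16 * 8, dtype=torch.float32).reshape(16, 8)
local = shard_packed_colwise(packed, world_size=4, rank=0)
print(local.shape)  # torch.Size([4, 8]): two rows from each packed block
```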