
Commit 20ce210

Revert "remove dtensors, not explicit (#39840)" (#39912)
* Revert "remove dtensors, not explicit (#39840)"

  This did not work with generation (lm_head needs extra care!)

  This reverts commit 6dfd561.

* update

* style?
1 parent 2589a52 commit 20ce210

File tree

4 files changed: +77 -73 lines changed

conftest.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ def check_output(self, want, got, optionflags):
 
 if is_torch_available():
     import torch
+
     # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
     # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
     torch.backends.cudnn.allow_tf32 = False
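For context, TF32 trades precision for speed on Ampere-and-newer GPUs, which can make doc-test numerics drift between machines. A minimal sketch of pinning both TF32 switches in a test session; only the cuDNN flag is what this conftest change touches, the matmul flag is an extra knob shown for completeness and is not part of this commit:

import torch

# Assumed CI guard: force full-precision float32 math so numeric
# comparisons are reproducible across GPU generations.
if torch.cuda.is_available():
    torch.backends.cudnn.allow_tf32 = False        # what conftest.py sets above
    torch.backends.cuda.matmul.allow_tf32 = False  # additional matmul switch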

src/transformers/integrations/tensor_parallel.py

Lines changed: 56 additions & 61 deletions
@@ -150,7 +150,6 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     "F64": torch.float64,
     "I64": torch.int64,
     "F8_E4M3": torch.float8_e4m3fn,
-    "F8_E5M2": torch.float8_e5m2,
 }
 
 
@@ -526,43 +525,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         return param
 
 
-class ReduceFromModelParallelRegion(torch.autograd.Function):
-    """
-    All-reduce in forward pass, identity in backward pass.
-    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        if device_mesh.size() == 1:
-            return x
-        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output
-
-
-class CopyToModelParallelRegion(torch.autograd.Function):
-    """
-    Copy in forward pass, all-reduce in backward pass.
-    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
-    """
-
-    @staticmethod
-    def forward(ctx, x, device_mesh):
-        ctx.device_mesh = device_mesh
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.device_mesh.size() == 1:
-            return grad_output
-        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
-        return grad_output
-
-
 class ColwiseParallel(TensorParallelLayer):
     """
     General tensor parallel layer for transformers.
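The two deleted autograd helpers are the Megatron-LM `f`/`g` conjugate operators: one is the identity in forward and an all-reduce of the gradient in backward, the other the mirror image. A hedged sketch of how such an operator is typically composed with a column-parallel linear, assuming `torch.distributed` is already initialized; the class and module names below are illustrative, not from this repository:

import torch
import torch.distributed as dist
import torch.nn as nn


class _CopyToTPRegion(torch.autograd.Function):
    """Megatron-style `f`: identity in forward, all-reduce of the gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_world_size() > 1:
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output


class ToyColumnParallelLinear(nn.Module):
    """Each rank owns a slice of the output features and returns only its local slice."""

    def __init__(self, in_features, out_features_per_rank):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features_per_rank, bias=False)

    def forward(self, x):
        # The replicated input goes through `f`, so gradients from all shards are summed on the way back.
        return self.proj(_CopyToTPRegion.apply(x))

The revert drops this hand-rolled pattern and goes back to expressing the same collectives through DTensor placements, as the hunks below show.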
@@ -585,8 +547,15 @@ def __init__(
 
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
         # annotate module input placements/sharding with input_layouts
         input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        # transform the input layouts to the desired layouts of ColwiseParallel
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
         return input_tensor
 
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
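The re-added `_prepare_input_fn` first tags the incoming activation with its current placement via `DTensor.from_local`, then lets `redistribute` insert whatever collective is needed to reach the desired layout. A minimal standalone sketch of that conversion, assuming a recent PyTorch with the public `torch.distributed.tensor` module and a launch via `torchrun --nproc_per_node=2` on CPU with the gloo backend; the tensor contents are made up:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Every rank holds the same replicated activation as a plain torch.Tensor.
activation = torch.arange(8, dtype=torch.float32).reshape(2, 4)

# Step 1: annotate the plain tensor with its current placement (no communication).
dtensor = DTensor.from_local(activation, mesh, [Replicate()], run_check=False)

# Step 2: ask for the layout the layer wants; redistribute picks the collective (here a local slice).
dtensor = dtensor.redistribute(placements=[Shard(1)])

print(f"rank {dist.get_rank()}: local shape {tuple(dtensor.to_local().shape)}")  # (2, 2) with 2 ranks
dist.destroy_process_group()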
@@ -595,19 +564,41 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         # weight would become Shard(1)
         if param_type == "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
         else:
+            shard = [Shard(-2)]
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
 
         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = CopyToModelParallelRegion.apply(outputs, device_mesh)
-        return outputs
+        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
+        # back to local tensor
+        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
+
+
+class PackedColwiseParallel(ColwiseParallel):
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
+        parameter = parameter.to(param_casting_dtype)
+        if to_contiguous:
+            parameter = parameter.contiguous()
+        if self.use_dtensor:
+            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 class RowwiseParallel(TensorParallelLayer):
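`partition_tensor` now returns parameters whose data is a DTensor built from the local shard; the commit also passes the global `shape`/`stride` to `from_local`, which matters for uneven shards, while the sketch below omits them since its shard is even. A hedged illustration of wrapping a weight shard and recovering the full weight, under the same 2-process gloo `torchrun` assumption; the sizes are made up:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("gloo")
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

# Pretend the full checkpoint weight is (8, 4) and each rank loaded only its block of rows
# (rows = output features, the dimension a colwise-parallel nn.Linear shards).
rows_per_rank = 8 // world_size
local_shard = torch.randn(rows_per_rank, 4)

weight = nn.Parameter(DTensor.from_local(local_shard, mesh, [Shard(0)], run_check=False))

# full_tensor() all-gathers the shards back into the global (8, 4) weight;
# this is what the save_pretrained change further down relies on.
print(f"rank {dist.get_rank()}: full weight {tuple(weight.data.full_tensor().shape)}")
dist.destroy_process_group()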
@@ -644,15 +635,23 @@ def __init__(
         self.use_dtensor = use_dtensor
 
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        if param_type == "bias":
-            parameter = param[:]
-        else:
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        if param_type != "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
+            shard = [Shard(-1)]
+        else:
+            shard = [Replicate()]
+            parameter = param[:]
 
         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-
+        if self.use_dtensor:
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
@@ -662,13 +661,24 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
             mod.bias = None
 
         input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
         return input_tensor
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh)
+        # Rowwise sharding produces partial output, depending on output layouts:
+        # 1. to replicate -> allreduce
+        # 2. to shard -> reduce_scatter
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        outputs = outputs.to_local()  # otherwise the `+=` op will gather
         if hasattr(mod, "_bias"):
             outputs += mod._bias
+        # back to local tensor if use_local_output is True
         return outputs
 
     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
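The rowwise output hook depends on the local matmul producing partial sums (each rank only saw its slice of the reduction dimension), which DTensor models with the `Partial` placement; redistributing to `Replicate` is then an all-reduce, and the bias is added afterwards on the local tensor so the `+=` does not trigger another collective. A hedged sketch of that partial-to-replicate reduction, again assuming a 2-process gloo `torchrun` launch with made-up values:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Stand-in for the local matmul result: rank 0 contributes 1.0 everywhere, rank 1 contributes 2.0.
partial_out = torch.full((2, 3), float(dist.get_rank() + 1))

# Partial() declares "these are summands, not replicas" (no communication yet).
out = DTensor.from_local(partial_out, mesh, [Partial()], run_check=False)

# Partial -> Replicate is realized as an all-reduce; with 2 ranks every entry becomes 3.0.
out = out.redistribute(placements=[Replicate()])
print(f"rank {dist.get_rank()}: {out.to_local()[0, 0].item()}")  # 3.0

dist.destroy_process_group()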
@@ -694,21 +704,6 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         )
 
 
-class PackedColwiseParallel(ColwiseParallel):
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # NOTE(3outeille): need to be deprecated as no longer using dtensors
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
-
 class PackedRowwiseParallel(RowwiseParallel):
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)

src/transformers/modeling_utils.py

Lines changed: 3 additions & 10 deletions
@@ -4087,16 +4087,9 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None:
-                    plan = _get_parameter_tp_plan(tensor, self._tp_plan)
-                    full_tensor = state_dict[tensor]
-                    if isinstance(state_dict[tensor], DTensor):
-                        full_tensor = full_tensor.full_tensor()
-                    elif plan is not None:
-                        shard_dim = -1 if "rowwise" in plan else 0
-                        gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())]
-                        torch.distributed.all_gather(gather_list, full_tensor)
-                        full_tensor = torch.cat(gather_list, dim=shard_dim)
+                if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
+                    full_tensor = state_dict[tensor].full_tensor()
+                # to get the correctly ordered tensor we need to repack if packed
                 if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
                     full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
                 shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
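Because parameters are DTensors again, consolidating a state dict at save time reduces to one `full_tensor()` call per entry instead of the manual `all_gather` bookkeeping deleted above. A minimal hedged sketch of that consolidation step; the helper name is illustrative and the packed-weight repacking is left out:

from torch.distributed.tensor import DTensor


def consolidate_state_dict(state_dict):
    """Return a copy of `state_dict` where every DTensor has been gathered into a plain full tensor."""
    full = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            value = value.full_tensor()  # all-gather the shards across the device mesh
        full[name] = value.contiguous()
    return full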

tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 17 additions & 2 deletions
@@ -101,6 +101,14 @@ def test_model_forward(self):
         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
         torch.distributed.barrier()
 
+        has_dtensor = 0
+        for name, parameter in model.named_parameters():
+            if isinstance(parameter.data, torch.distributed.tensor.DTensor):
+                has_dtensor = 1
+                break
+
+        assert has_dtensor == 1, "TP model must has DTensor"
+
         tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
         prompt = "Can I help"
 
@@ -110,8 +118,7 @@ def test_model_forward(self):
         next_token_logits = outputs[0][:, -1, :]
         next_token = torch.argmax(next_token_logits, dim=-1)
         response = tokenizer.decode(next_token)
-        print(response)
-        # assert response == "with"
+        assert response == "with"
 
         torch.distributed.barrier()
         torch.distributed.destroy_process_group()
@@ -136,6 +143,14 @@ def test_model_generate(self):
 
         model.forward = torch.compile(model.forward)
 
+        has_dtensor = 0
+        for name, parameter in model.named_parameters():
+            if isinstance(parameter.data, torch.distributed.tensor.DTensor):
+                has_dtensor = 1
+                break
+
+        assert has_dtensor == 1, "TP model must has DTensor"
+
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         prompt = "Can I help"
 