Skip to content

Commit 70c4d21

Browse files
JZ-LIANGpkuzyc
andauthored
[Auto-Parallel] Reshard API & Hybrid Parallel Unitest for dy2static mode (#59856)
* stsatic_decorate v0.1 * update static_decorate as comments * add unit tests and adapt placement api * add docs for the api * remove useless print and comments * first commit * stsatic_decorate v0.1 * update static_decorate as comments * add unit tests and adapt placement api * add docs for the api * remove useless print and comments * add unit execution code * fix sample code for static_decorate * add get_program interface in DistModel * modify as suggested * move the init parameters part to helper.py * fix unittest name in CMakeList * add api * add unitest * typoes * add dy2static test case for llama * static dist model pp * bug fixed * enable all test * fix typoes --------- Co-authored-by: Yichen Zhang <[email protected]>
1 parent 5be87ba commit 70c4d21

File tree

10 files changed

+593
-28
lines changed

10 files changed

+593
-28
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,26 @@
1717
import paddle
1818
import paddle.distributed as dist
1919
from paddle import nn
20-
from paddle.base.framework import EagerParamBase
20+
from paddle.base import unique_name
21+
from paddle.base.framework import (
22+
EagerParamBase,
23+
Variable,
24+
default_main_program,
25+
)
2126
from paddle.distributed.auto_parallel import Engine
2227
from paddle.distributed.auto_parallel.interface import (
2328
shard_tensor as shard_tensor_static,
2429
)
30+
from paddle.distributed.auto_parallel.static.completion import (
31+
mark_as_sharding_propagation_skip_op,
32+
)
33+
from paddle.distributed.auto_parallel.static.dist_context import (
34+
get_default_distributed_context,
35+
)
36+
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
37+
from paddle.distributed.auto_parallel.static.utils import (
38+
convert_to_dims_mapping,
39+
)
2540
from paddle.framework import core
2641

2742
from .placement_type import get_shard_spec
@@ -131,6 +146,7 @@ def __init__(
131146
)
132147
self._mode = None
133148
self._feed_name_list = {}
149+
134150
# convert dygraph model to static model
135151
batch_size = loader.batch_sampler.batch_size
136152
inputs_spec, labels_spec = self._engine._prepare_data_spec(
@@ -268,15 +284,20 @@ def __call__(self, *args):
268284
raise ValueError("Please set loss function before evaluation.")
269285
feeds = self._make_feeds(list(args))
270286
outs = self._engine.run(feeds)
287+
271288
if self._mode == "predict":
272-
return outs["outputs"]
289+
if "outputs" in outs:
290+
return outs["outputs"]
291+
else:
292+
return None
273293
else:
274-
return outs["loss"]
294+
if "loss" in outs:
295+
return outs["loss"]
296+
else:
297+
return None
275298

276299

277300
# Part2: DistTensor construction related APIs
278-
279-
280301
def to_static(
281302
layer: paddle.nn.Layer,
282303
loader=None,
@@ -566,10 +587,55 @@ def reshard(dist_tensor, mesh, placements):
566587

567588
return paddle.base.core.reshard(dist_tensor, dist_attr)
568589
else:
569-
# TODO(GhostScreaming): Support static DistTensor later.
570-
raise RuntimeError(
571-
"paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
590+
assert isinstance(
591+
dist_tensor, Variable
592+
), "in dy2static mode, reshard's input should be Variable, but got [{}]".format(
593+
dist_tensor
594+
)
595+
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
596+
main_program = default_main_program()
597+
default_dist_ctx = get_default_distributed_context()
598+
599+
# output variable
600+
out_var = main_program.current_block().create_var(
601+
name=unique_name.generate_with_ignorable_key(
602+
".".join(['reshard_api', 'tmp'])
603+
),
604+
dtype=dist_tensor.dtype,
605+
shape=dist_tensor.shape,
606+
type=dist_tensor.type,
607+
persistable=dist_tensor.persistable,
608+
stop_gradient=dist_tensor.stop_gradient,
609+
)
610+
611+
# transition op
612+
# optimization in future to remove redundant D2D memory copy
613+
target_dims_mapping = convert_to_dims_mapping(sharding_specs, mesh)
614+
trans_op = main_program.current_block().append_op(
615+
type='assign',
616+
inputs={'X': [dist_tensor]},
617+
outputs={'Out': [out_var]},
618+
)
619+
dist_op = DistributedOperator(trans_op)
620+
dist_op.dist_attr.process_mesh = mesh
621+
dist_op.dist_attr.mark_annotated("process_mesh")
622+
dist_op.dist_attr.chunk_id = 0
623+
624+
input_dist_attr = dist_op.dist_attr.get_input_dist_attr(
625+
dist_tensor.name
572626
)
627+
input_dist_attr.dims_mapping = target_dims_mapping
628+
input_dist_attr.mark_annotated("dims_mapping")
629+
output_dist_attr = dist_op.dist_attr.get_output_dist_attr(out_var.name)
630+
output_dist_attr.dims_mapping = target_dims_mapping
631+
output_dist_attr.mark_annotated("dims_mapping")
632+
633+
default_dist_ctx.add_dist_op_for_program(dist_op)
634+
mark_as_sharding_propagation_skip_op(trans_op)
635+
# trans_op = shard_op_static(paddle.assign, mesh, [sharding_specs])
636+
# out_var = trans_op(dist_tensor)
637+
638+
return out_var
573639

574640

575641
def shard_layer(

python/paddle/distributed/auto_parallel/static/completion.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import logging
1717
import os
1818

19+
import paddle
1920
from paddle.base.core import ( # noqa: F401
2021
contains_spmd_rule,
2122
get_phi_spmd_rule,
2223
get_spmd_rule,
2324
)
25+
from paddle.base.framework import Operator
2426
from paddle.base.log_helper import get_logger
2527
from paddle.distributed.fleet.meta_optimizers.common import OpRole
2628
from paddle.framework import core
@@ -53,6 +55,24 @@
5355
"read",
5456
]
5557

58+
_skip_propagation_prefix = "Auto_Parallel_Completion_Skipped"
59+
60+
61+
def mark_as_sharding_propagation_skip_op(op):
62+
op._set_attr('op_namescope', '/' + _skip_propagation_prefix)
63+
64+
65+
def is_sharding_propagation_skip_op(op):
66+
if isinstance(op, paddle.base.libpaddle.OpDesc):
67+
op_desc = op
68+
elif isinstance(op, Operator):
69+
op_desc = op.desc
70+
else:
71+
raise RuntimeError(f"static mode operator is expected but got [{op}]")
72+
return op_desc.has_attr(
73+
"op_namescope"
74+
) and _skip_propagation_prefix in op_desc.attr("op_namescope")
75+
5676

5777
def compute_compatible_dim_mapping(dim_mapping_list):
5878
"""Compute the compatible dim mapping given a list of dim mapping."""
@@ -218,6 +238,7 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
218238
or pred_op_node.op().type()
219239
== "create_double_buffer_reader"
220240
or pred_op_node.op().type() == "read"
241+
# or is_sharding_propagation_skip_op(pred_op_node.op()) # reshard should only fwd tensor propagation
221242
):
222243
continue
223244
op_dist_attr = (
@@ -255,6 +276,7 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
255276
or succ_op_node.op().type()
256277
== "create_double_buffer_reader"
257278
or succ_op_node.op().type() == "read"
279+
or is_sharding_propagation_skip_op(succ_op_node.op())
258280
):
259281
continue
260282
op_dist_attr = (
@@ -293,7 +315,10 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True):
293315
if (not op_node.is_op()) or (op_node.op() is None):
294316
return False
295317
# Skip reader op
296-
if op_desc.type() in __skip_dims_mapping_op__:
318+
if (
319+
op_desc.type() in __skip_dims_mapping_op__
320+
or is_sharding_propagation_skip_op(op_node.op())
321+
):
297322
return False
298323

299324
dist_op = self._dist_context.get_dist_op_for_graph(op_node)

test/auto_parallel/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
178178
test_semi_auto_parallel_dist_to_static)
179179
set_tests_properties(test_semi_auto_parallel_dist_to_static
180180
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)
181+
py_test_modules(test_static_reshard_api MODULES test_static_reshard_api)
182+
set_tests_properties(test_static_reshard_api
183+
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)
184+
181185
# End of unittests WITH multi cards and timeout
182186

183187
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout

test/auto_parallel/hybrid_strategy/semi_auto_llama.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from functools import reduce
1717

1818
import numpy as np
19-
from semi_auto_parallel_llama_model import LlamaForCausalLMAuto, set_global_mesh
19+
from semi_auto_parallel_llama_model import (
20+
LlamaForCausalLMAuto,
21+
LlamaPretrainingCriterionAuto,
22+
set_global_mesh,
23+
)
2024

2125
import paddle
2226
import paddle.distributed as dist
@@ -104,8 +108,9 @@ def init_dist_env(self):
104108
global_mesh = dist.ProcessMesh(mesh_arr, dim_names)
105109
set_global_mesh(global_mesh)
106110

107-
def run_test_cases(self):
111+
def run_dynamic(self):
108112
model = LlamaForCausalLMAuto(self.config)
113+
criterion = LlamaPretrainingCriterionAuto(self.config)
109114

110115
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
111116
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
@@ -133,7 +138,8 @@ def run_test_cases(self):
133138
for epoch_idx in range(1):
134139
for step, inputs in enumerate(train_dataloader):
135140
input_ids, labels = inputs
136-
tr_loss_step, _ = model(input_ids, labels=labels)
141+
logits = model(input_ids)
142+
tr_loss_step = criterion(logits, labels)
137143

138144
if self.gradient_accumulation_steps > 1:
139145
tr_loss_step /= self.gradient_accumulation_steps
@@ -154,6 +160,51 @@ def run_test_cases(self):
154160
if global_step // self.gradient_accumulation_steps >= 10:
155161
break
156162

163+
def run_dy2static(self):
164+
model = LlamaForCausalLMAuto(self.config)
165+
criterion = LlamaPretrainingCriterionAuto(self.config)
166+
167+
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
168+
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
169+
)
170+
optimizer = create_optimizer(model, lr_scheduler)
171+
optimizer = dist.shard_optimizer(optimizer)
172+
173+
train_dataset = RandomDataset(self.config.seq_length)
174+
train_sampler = BatchSampler(
175+
train_dataset,
176+
batch_size=2,
177+
shuffle=True,
178+
drop_last=True,
179+
)
180+
train_dataloader = DataLoader(
181+
train_dataset,
182+
batch_sampler=train_sampler,
183+
num_workers=0,
184+
)
185+
186+
if isinstance(optimizer, dist.auto_parallel.api._ShardOptimizer):
187+
opt = optimizer._inner_opt
188+
else:
189+
opt = optimizer
190+
191+
dist_model, dist_loader = dist.to_static(
192+
model, train_dataloader, criterion, opt
193+
)
194+
195+
dist_model.train()
196+
for step, inputs in enumerate(dist_loader()):
197+
input_ids, labels = inputs
198+
loss = dist_model(input_ids, labels)
199+
print(step, loss)
200+
201+
if step >= 10:
202+
break
203+
204+
def run_test_cases(self):
205+
self.run_dynamic()
206+
self.run_dy2static()
207+
157208

158209
if __name__ == '__main__':
159210
TestLlamaAuto().run_test_cases()

test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def forward(
251251
attn_output = outputs
252252

253253
attn_output = self.o_proj(attn_output)
254-
254+
# TODO add should be in SP region
255255
if self.config.sequence_parallel:
256256
attn_output = paddle.transpose(attn_output, [1, 0, 2])
257257
attn_output = dist.reshard(
@@ -501,7 +501,7 @@ def _prepare_decoder_attention_mask(
501501
combined_attention_mask = dist.shard_tensor(
502502
combined_attention_mask,
503503
get_mesh(),
504-
[dist.Shard(0), dist.Replicate()],
504+
[dist.Replicate(), dist.Replicate()],
505505
)
506506
expanded_attn_mask = (
507507
expanded_attn_mask & combined_attention_mask
@@ -582,7 +582,7 @@ def forward(
582582
(batch_size, seq_length)
583583
)
584584
position_ids = dist.shard_tensor(
585-
position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
585+
position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()]
586586
)
587587

588588
if self.config.sequence_parallel:
@@ -607,7 +607,7 @@ def forward(
607607
all_self_attns = () if output_attentions else None
608608
next_decoder_cache = () if use_cache else None
609609

610-
pre_ipp = 0
610+
pre_ipp = None
611611
for idx, (decoder_layer) in enumerate(self.layers):
612612
if output_hidden_states:
613613
all_hidden_states += (hidden_states,)
@@ -708,6 +708,9 @@ def __init__(self, config):
708708
)
709709

710710
def forward(self, prediction_scores, masked_lm_labels):
711+
masked_lm_labels = dist.shard_tensor(
712+
masked_lm_labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
713+
)
711714
masked_lm_loss = self.loss_func(
712715
prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
713716
)
@@ -732,7 +735,7 @@ def __init__(self, config):
732735

733736
self.llama = LlamaModelAuto(config)
734737
self.lm_head = LlamaLMHeadAuto(config)
735-
self.criterion = LlamaPretrainingCriterionAuto(config)
738+
# self.criterion = LlamaPretrainingCriterionAuto(config)
736739

737740
def forward(
738741
self,
@@ -770,25 +773,27 @@ def forward(
770773

771774
hidden_states = outputs[0] # [bs, seq_len, dim]
772775

776+
# if labels is None,means we need full output, instead of tensor_parallel_output
773777
if self.config.sequence_parallel:
774778
hidden_states = dist.reshard(
775779
hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
776780
)
777781
# [S, B, H] -> [B, S, H]
778782
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
779-
# if labels is None,means we need full output, instead of tensor_parallel_output
783+
780784
logits = self.lm_head(hidden_states)
781785

782-
loss = None
783-
if labels is not None:
784-
labels.stop_gradient = True
785-
labels = dist.shard_tensor(
786-
labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
787-
)
788-
loss = self.criterion(logits, labels)
786+
# loss = None
787+
# if labels is not None:
788+
# labels.stop_gradient = True
789+
# labels = dist.shard_tensor(
790+
# labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
791+
# )
792+
# loss = self.criterion(logits, labels)
789793

790-
output = (logits,) + outputs[1:]
791-
return (loss,) + output if loss is not None else output
794+
# output = (logits,) + outputs[1:]
795+
# return (loss,) + output if loss is not None else output
796+
return logits
792797

793798

794799
def _expand_2d_mask(mask, dtype, tgt_length):

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_simple_net_hybrid_strategy(self):
164164
class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
165165
def setUp(self):
166166
super().setUp(num_of_devices=8, timeout=200, nnode=1)
167-
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
167+
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "1"}
168168
self._changeable_envs = {
169169
"backend": ["gpu"],
170170
"use_sp": ["true", "false"],

test/auto_parallel/semi_auto_parallel_dist_to_static_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def run_test(self):
200200
# not prepared
201201
# NOTE: This use is not recommended, only for the test. In normal
202202
# use, DistModel is generated by dist.to_static.
203+
203204
dist_model._engine._has_prepared["train"] = False
204205
dist_model._engine._has_prepared["eval"] = False
205206
dist_model._engine._has_prepared["predict"] = False

0 commit comments

Comments
 (0)