Commit 99a281b

[Auto Parallel] Add Strategy api for configuring the distributed training with static graph (#59862)
* move strategy to api.py
* add Strategy api
* fix sample code
* add detailed comments for the configs in dist.Strategy
* add an error case in unit test
* add the unit test to CMakeLists
1 parent a339e8b commit 99a281b
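
For orientation, here is a minimal sketch of how the new API is meant to be used end to end, pieced together from the docstrings added in this commit. The MLP layer, dataset, loss and optimizer below are hypothetical placeholders, and the script is assumed to run under a distributed launcher such as ``python -m paddle.distributed.launch``.

import paddle
import paddle.distributed as dist

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

layer = MLP()
dataset = paddle.io.TensorDataset(
    [paddle.rand([32, 8]), paddle.rand([32, 2])]
)
loader = paddle.io.DataLoader(dataset, batch_size=4)
loss_fn = paddle.nn.MSELoss()
opt = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=layer.parameters()
)

# Configure the new user-facing Strategy and hand it to to_static.
strategy = dist.Strategy()
strategy.fused_passes.enable = True
strategy.fused_passes.gemm_epilogue = True

dist_model, dist_loader = dist.to_static(
    layer, loader, loss_fn, opt, strategy=strategy
)

# Set the mode first, then drive training through __call__.
dist_model.train()
for inputs, labels in dist_loader():
    loss = dist_model(inputs, labels)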

8 files changed (+372, -24 lines)

python/paddle/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -85,6 +85,7 @@
     shard_layer,
     shard_optimizer,
     to_static,
+    Strategy,
 )
 
 from .fleet import BoxPSDataset  # noqa: F401
@@ -165,4 +166,5 @@
     "load_state_dict",
     "shard_optimizer",
     "to_static",
+    "Strategy",
 ]
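
With the two additions above, the new class is re-exported at the package level, so both of the following spellings should refer to the same object (a minimal check, assuming a Paddle build that contains this commit):

import paddle.distributed as dist
from paddle.distributed import Strategy

# `Strategy` is imported from paddle.distributed.auto_parallel.api and
# listed in __all__, so the package-level alias is the same class.
assert dist.Strategy is Strategy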

python/paddle/distributed/auto_parallel/api.py

Lines changed: 249 additions & 20 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 from collections import defaultdict
 from typing import Callable
 
@@ -23,7 +24,7 @@
     Variable,
     default_main_program,
 )
-from paddle.distributed.auto_parallel import Engine
+from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,
 )
@@ -108,9 +109,11 @@ def sharding_specs(self):
 
 class DistModel:
     """
-    DistModel is a wrapper of the network model for the use of static mode
-    auto parallel. DistModel contains the distributed Graph of the model and
-    offers the APIs for training, evaluation and prediction.
+    DistModel is generated by ``paddle.distributed.to_static``. It contains the
+    static graph converted from a ``paddle.nn.Layer`` whose parameters are
+    distributed tensors (constructed from ``paddle.distributed.shard_tensor``),
+    and provides the APIs for training, evaluation and prediction with the
+    static graph.
 
     Please first set the DistModel to "train", "eval" or "predict" mode and
     then use the __call__ method for training, evaluation and prediction
@@ -127,8 +130,8 @@ class DistModel:
     to "eval" mode in default. If loss and optimizer are both None, DistModel
     will be set to "predict" mode in default.
 
-    DistModel is generated by ``paddle.distributed.to_static``, for more details
-    of the usage, please refer to the sample code in ``paddle.distributed.to_static``.
+    For more details of the usage, please refer to the sample code in
+    ``paddle.distributed.to_static``.
     """
 
     def __init__(
@@ -141,8 +144,9 @@ def __init__(
         metrics=None,
     ):
         self._feed_name_list = []
+        self._inner_strategy = self.__convert_strategy(strategy)
         self._engine = Engine(
-            layer, loss, optimizer, metrics, strategy=strategy
+            layer, loss, optimizer, metrics, strategy=self._inner_strategy
         )
         self._mode = None
         self._feed_name_list = {}
@@ -271,6 +275,27 @@ def _make_feeds(self, data_list):
             )
         return dict(zip(feed_name_list, data_list))
 
+    def __convert_strategy(self, strategy):
+        import copy
+
+        if strategy is None:
+            return None
+        inner_strategy = auto_strategy.Strategy()
+        inner_strategy.fused_passes.enable = strategy.fused_passes.enable
+        if strategy.fused_passes.gemm_epilogue is True:
+            inner_strategy.fused_passes.fused_passes_list.append(
+                "fused_gemm_epilogue_pass"
+            )
+        if strategy.fused_passes.dropout_add is True:
+            inner_strategy.fused_passes.fused_passes_list.append(
+                "fused_dropout_add_pass"
+            )
+
+        inner_strategy.sharding = copy.deepcopy(strategy.sharding)
+        inner_strategy.gradient_merge = copy.deepcopy(strategy.gradient_merge)
+        inner_strategy.pipeline = copy.deepcopy(strategy.pipeline)
+        return inner_strategy
+
     def __call__(self, *args):
         if self._mode is None:
             raise ValueError("Please call train()/eval()/predict() first.")
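
A note on the helper added above: the public ``dist.Strategy`` exposes fused passes as boolean switches, while the internal ``auto_strategy.Strategy`` expects a list of pass names, so ``__convert_strategy`` translates the former into the latter and deep-copies the remaining sections. A rough illustration of the mapping (``user`` is a hypothetical ``dist.Strategy`` instance, not part of this commit):

import paddle.distributed as dist

user = dist.Strategy()
user.fused_passes.enable = True
user.fused_passes.gemm_epilogue = True
user.fused_passes.dropout_add = False

# After DistModel.__convert_strategy(user), the internal strategy roughly holds:
#   fused_passes.enable            -> True
#   fused_passes.fused_passes_list -> ["fused_gemm_epilogue_pass"]
#   sharding / gradient_merge / pipeline -> deep copies of the user's sections
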
@@ -298,6 +323,209 @@ def __call__(self, *args):
 
 
 # Part2: DistTensor construction related APIs
+
+
+class FusePasses:
+    """
+    A helper class for users to configure the fuse passes.
+    """
+
+    def __init__(self, config_dict=None):
+        self.enable = False
+        self.gemm_epilogue = False
+        self.dropout_add = False
+        if config_dict is not None:
+            for key, value in config_dict.items():
+                if hasattr(self, key):
+                    setattr(self, key, value)
+                else:
+                    raise ValueError(f"Unknown fuse pass {key}")
+
+
+class Strategy(auto_strategy.BaseConfig):
+    """
+    The `Strategy` object is used to configure the parallelization
+    and optimization strategies for the static graph. It currently supports
+    configuring ``sharding``, ``fused_passes``, ``gradient_merge``
+    and ``pipeline``. More strategies will be supported in the future.
+
+    ``sharding`` is used to configure the sharding states of the optimizer,
+    for saving the GPU memory.
+
+    ``fused_passes`` is used to configure the fusion of the computation in
+    the model.
+
+    ``gradient_merge`` is used to configure the gradient merge strategy in
+    training.
+
+    ``pipeline`` is used to configure the pipeline parallelism strategy.
+
+    Args:
+        config (dict|None, optional): If ``config`` is None, the default
+            configurations will be set. If it is a dict, the items inside
+            the dict will be used to set the configurations, and the others
+            remain the default values.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.distributed as dist
+
+            >>> strategy = dist.Strategy()
+
+            >>> strategy.sharding.enable = True
+            >>> strategy.sharding.stage = 2
+            >>> strategy.sharding.degree = 2
+
+            >>> strategy.gradient_merge.enable = True
+            >>> strategy.gradient_merge.k_steps = 2
+            >>> strategy.gradient_merge.avg = False
+
+            >>> strategy.pipeline.enable = True
+            >>> strategy.pipeline.schedule_mode = "1F1B"  # default is "1F1B"
+            >>> strategy.pipeline.micro_batch_size = 2
+    """
+
+    def __init__(self, config=None):
+        if config is not None:
+            if isinstance(config, dict):
+                self._config_dict = copy.deepcopy(config)
+            else:
+                raise ValueError(
+                    f"Expected a dictionary. But received: {config}"
+                )
+        else:
+            self._config_dict = {}
+
+        category = auto_strategy.constants.BASE
+        super().__init__(category, self._config_dict)
+
+        config_dict = self._config_dict.get(
+            auto_strategy.constants.SHARDING, None
+        )
+        self._sharding = auto_strategy.ShardingConfig(config_dict)
+
+        config_dict = self._config_dict.get(
+            auto_strategy.constants.GRADIENT_MERGE, None
+        )
+        self._gradient_merge = auto_strategy.GradientMergeConfig(config_dict)
+
+        config_dict = self._config_dict.get(
+            auto_strategy.constants.PIPELINE, None
+        )
+        self._pipeline = auto_strategy.PipelineConfig(config_dict)
+
+        config_dict = self._config_dict.get(
+            auto_strategy.constants.FUSED_PASSES, None
+        )
+        self._fused_passes = FusePasses(config_dict)
+
+    @property
+    def sharding(self):
+        """
+        ``sharding`` is used to configure the sharding states of the optimizer,
+        containing the following configs:
+
+        ``enable`` (bool): whether to enable sharding. Default: False.
+
+        ``stage`` (int): can be set to 1, 2 or 3. 1 indicates the optimizer states segmentation,
+        2 indicates optimizer states and gradient segmentation, 3 indicates the segmentation
+        of optimizer states, gradient and parameters. Default: 1.
+
+        ``degree`` (int): the number of segmentation pieces. Default: 8.
+
+        Examples:
+            .. code-block:: python
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> strategy = dist.Strategy()
+
+                >>> strategy.sharding.enable = True
+                >>> strategy.sharding.stage = 2
+                >>> strategy.sharding.degree = 2
+        """
+        return self._sharding
+
+    @property
+    def gradient_merge(self):
+        """
+        ``gradient_merge`` is used to configure the gradient merge strategy in
+        training, containing the following configs:
+
+        ``enable`` (bool): whether to enable gradient merge. Default: False.
+
+        ``k_steps`` (int): the number of steps for merging gradients. Default: 1.
+
+        ``avg`` (bool): whether to average the gradients of each step. Default: True.
+
+        Examples:
+            .. code-block:: python
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> strategy = dist.Strategy()
+
+                >>> strategy.gradient_merge.enable = True
+                >>> strategy.gradient_merge.k_steps = 2
+                >>> strategy.gradient_merge.avg = True
+        """
+        return self._gradient_merge
+
+    @property
+    def fused_passes(self):
+        """
+        ``fused_passes`` is used to configure the fusion of the computation in
+        the model, containing the following configs:
+
+        ``enable`` (bool): whether to enable fused passes. Default: False.
+
+        ``gemm_epilogue`` (bool): whether to fuse ``matmul`` and ``add`` computation
+        in the ``Linear`` layer. Default: False.
+
+        ``dropout_add`` (bool): whether to fuse ``dropout`` and ``add`` computation. Default: False.
+
+        Examples:
+            .. code-block:: python
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> strategy = dist.Strategy()
+
+                >>> strategy.fused_passes.enable = True
+                >>> strategy.fused_passes.gemm_epilogue = True
+                >>> strategy.fused_passes.dropout_add = True
+        """
+        return self._fused_passes
+
+    @property
+    def pipeline(self):
+        """
+        ``pipeline`` is used to configure the pipeline parallelism in training,
+        containing the following configs:
+
+        ``enable`` (bool): whether to enable pipeline parallelism. Default: False.
+
+        ``schedule_mode`` (str): the scheduling mode of pipeline parallelism. Default: "1F1B".
+
+        ``micro_batch_size`` (int): the size of each micro-batch inside a mini-batch. Default: 1.
+
+        ``accumulate_steps`` (int): number of steps for accumulating. Default: 1.
+
+        Examples:
+            .. code-block:: python
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> strategy = dist.Strategy()
+
+                >>> strategy.pipeline.enable = True
+                >>> strategy.pipeline.micro_batch_size = 2
+        """
+        return self._pipeline
+
+
 def to_static(
     layer: paddle.nn.Layer,
     loader=None,
@@ -306,29 +534,30 @@ def to_static(
     strategy=None,
 ):
     """
-    Converts the model and data loader used in dygraph auto-parallelism to
-    that in static mode auto-parallelism. to_static returns a DistModel
-    instance that provides APIs and a DistributedDataLoader to generate data
-    for static mode auto-parallel training, evaluation and prediction.
+    Converts the ``layer`` with distributed tensors (constructed from
+    ``paddle.distributed.shard_tensor``) to a static graph. to_static
+    returns a DistModel instance containing the static graph for
+    distributed training, evaluation and prediction, and an object of
+    DistributedDataLoader to generate data.
 
     Args:
-        layer(paddle.nn.Layer): The layer in dygraph model, the parameters
-            or its inputs can be sharded.
-        loader(paddle.io.DataLoader): The data loader used in dygraph model,
-            used to generate Distributed Dataloader for static auto parallel.
+        layer(paddle.nn.Layer): The layer in dygraph mode, whose parameters
+            or inputs can be distributed tensors.
+        loader(paddle.io.DataLoader): The data loader used in dygraph mode,
+            used to generate the DistributedDataLoader.
         loss(Loss|Callable|None, optional): The loss function for training
            or evaluating the model. Can be a `paddle.nn.Layer` instance or
           any callable function. Default: None.
        optimizer(paddle.optimizer.Optimizer|None, optional): The optimizer
            for training. Default: None.
-        strategy(Strategy|None, optional): Configs for parallel strategies
-            (e.g. data parallel, hybrid parallel etc.) and optimization
-            settings (e.g. mixed-precision). Default: None.
+        strategy(paddle.distributed.Strategy|None, optional): Configs for
+            parallel strategies and optimization settings (e.g. sharding,
+            pipeline parallelism). Default: None.
 
     Returns:
        DistModel: A DistModel that contains the corresponding computational graph
-           for the input layer and provides APIs for training, evaluation and
-           prediction.
+           for the input ``layer`` and provides APIs for training, evaluation
+           and prediction.
        DistributedDataLoader: An optimized data loader that can be used
            to generate data.
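
The ``config`` dict accepted by the new ``Strategy`` constructor mirrors the sub-configs documented above. A minimal sketch, assuming the section keys follow the lower-case attribute names that ``auto_strategy.constants`` resolves to (e.g. ``"sharding"``, ``"fused_passes"``); this example is illustrative and not part of the commit:

import paddle.distributed as dist

# Build a Strategy from a plain dict instead of setting attributes one by one.
config = {
    "sharding": {"enable": True, "stage": 2, "degree": 2},
    "fused_passes": {"enable": True, "gemm_epilogue": True},
}
strategy = dist.Strategy(config)

# Expected to hold if the assumed key names above are correct.
assert strategy.sharding.stage == 2
assert strategy.fused_passes.gemm_epilogue is True

# Anything other than a dict (or None) is rejected by the constructor:
# dist.Strategy("sharding")  ->  ValueError: Expected a dictionary. ...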

python/paddle/distributed/auto_parallel/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def __init__(self, config_dict=None):
 
 class Strategy(BaseConfig):
     """
-    The `Strategy` object is used to configure the parallelization and optimization behaviors.
+    The `Strategy` object is used to configure the parallelization and optimization for the static graph.
 
     Args:
         config (dict|string, optional): If this is None, the default configurations will be used.

test/auto_parallel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -269,6 +269,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_cost_interface MODULES test_cost_interface)
   py_test_modules(test_auto_conditional_block MODULES
                   test_auto_conditional_block)
+  py_test_modules(test_strategy_api MODULES test_strategy_api)
   # End of unittests WITH single card WITHOUT timeout
 
 endif()

test/auto_parallel/hybrid_strategy/semi_auto_llama.py

Lines changed: 8 additions & 1 deletion
@@ -188,8 +188,15 @@ def run_dy2static(self):
         else:
             opt = optimizer
 
+        strategy = None
+        if self.gradient_accumulation_steps > 1:
+            strategy = dist.Strategy()
+            strategy.pipeline.accumulate_steps = (
+                self.gradient_accumulation_steps
+            )
+
         dist_model, dist_loader = dist.to_static(
-            model, train_dataloader, criterion, opt
+            model, train_dataloader, criterion, opt, strategy=strategy
         )
 
         dist_model.train()

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def test_simple_net_hybrid_strategy(self):
 class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=8, timeout=200, nnode=1)
-        self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "1"}
+        self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
         self._changeable_envs = {
             "backend": ["gpu"],
             "use_sp": ["true", "false"],
