Skip to content

Commit 778dba3

Browse files
authored
[AutoParallel] Update Doc for AutoParallel (#6695)
* refactor framework * update auto parallel * update auto parallel * code style
1 parent 283f453 commit 778dba3

File tree

4 files changed

+63
-14
lines changed

4 files changed

+63
-14
lines changed

docs/guides/06_distributed_training/auto_parallel_cn.md

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
自动并行训练
22
=======================
33

4+
# 一、背景动机
5+
46
超大模型已经成为 AI 最重要的核心竞争力之一,随着模型规模持续快速增长,各种并行策略和关键技术相继提出,可以看出底层平台技术已呈收敛趋势,超大模型分布式训练逐渐地从『增量』过渡到『存量』竞争。如何更灵活支持各类应用场景对复杂并行策略的需求,如何帮助用户更简单进行分布式训练,如何兼容动态图灵活调试和静态图性能优势的优点等都是亟待解决挑战。
57

68
飞桨当前支持分布式训练当前有几种方式:动态图手动并行(动手)、静态图手动并行(静手)、动静统一自动并行(自动并行)几种方式。
79

8-
手动并行(包括动态图手动并行和静态图手动并行)需要直接在用户感知到分布式实现的细节,例如通信组 ``process_group``,以及在组网中添加各类并行策略相关的 API,例如张量并行的 ``ColumnParallelLinear````RowParallelLinear`` 等,对于不同的分布式并行策略,都需要调用不同的接口,相对来说比较使用起来复杂。
10+
手动并行(包括动态图手动并行和静态图手动并行)需要用户在组网时直接感知到分布式实现的细节,例如通信原语``Allreduce````Allgather``,通信组 ``process_group``,以及在组网中添加各类并行策略相关的 API,例如张量并行的 ``ColumnParallelLinear````RowParallelLinear`` 等,对于不同的分布式并行策略,都需要调用不同的接口,相对来说比较使用起来复杂。
911

1012
自动并行为了降低用户开发分布式程序的门槛,提供了对不同分布式并行策略的统一抽象,让用户可以通过 `张量切分` 的语法标记即可实现不同并行策略。用户仅需使用少量的张量切分标注,框架便能自动推导出所有张量和算子的分布式切分状态,并添加合适的通信算子。同时自动并行还支持一键动转静分布式训练,开发者可以快速实现任意混合并行策略,大幅简化了混合并行训练代码的开发过程。
1113

12-
一、自动并行相关 API
13-
--------
14+
# 二、基本概念
15+
16+
## 2.1 自动并行 API
1417

1518
根据功能,我们将自动并行支持的 API 分为标记信息、动转静、Save&Load 三类。
1619

@@ -29,12 +32,11 @@
2932
* paddle.distributed.save_state_dict:保存模型参数结构到指定路径
3033
* paddle.distributed.load_state_dict:从指定路径加载模型
3134

32-
二、分布式张量
33-
--------
35+
## 2.2 分布式张量
3436

3537
目前已有的分布式策略,数据并行、模型并行等,都是通过(1)切分输入/输出(2)切分模型参数 (3)切分计算 这三种方式,满足在多计算设备上加速训练大模型的需求。为了提供更易用的分布式接口,我们引入分布式张量这一概念,描述由多个计算设备上的局部物理张量通过既定计算共同组成的逻辑张量,用户可以通过 paddle.distributed.shard_tensor 来创建分布式张量。
3638

37-
为了描述分布式张量和计算设备之前的映射关系,我们引入 ``Placements````ProcessMesh`` 两个分布式概念``Placements`` 是由 ``Replicate````Shard````Partial`` 三种分布式标记组成的列表,长度和 ``ProcessMesh`` 的维度一致,用于表示分布式张量在对应计算设备的维度上,按照哪种分布式标记做切分,这三种分布式标记的详细描述如下:
39+
为了描述分布式张量和计算设备之前的映射关系,我们引入 ``Placements````ProcessMesh`` 两个分布式概念``ProcessMesh`` 可以理解为是用一个高维矩阵对分布集群中计算设备的抽象,比如一个 4 机 32 卡的集群可以用一个 shape=[4,8] 的 mesh 矩阵进行描述;``Placements`` 是由 ``Replicate````Shard````Partial`` 三种分布式标记组成的列表,长度和 ``ProcessMesh`` 的维度个数一致,用于表示分布式张量在对应计算设备的维度上,按照什么方式做切分,这三种分布式标记的详细描述如下:
3840

3941
* Replicate,指张量在所有计算设备上保持全量状态。
4042
* Shard(axis),指将张量沿 axis 维度做切分后,放到不同的计算设备上。
@@ -60,7 +62,10 @@ dist_tensor = dist.shard_tensor(dense_tensor, mesh, placements)
6062
```
6163
![切分状态](images/shard.png)
6264

63-
同时,为了提供 ``重切分`` 的能力,我们提供 ``paddle.distributed.reshard`` 接口,支持跨 ``ProcessMesh`` 的分布式张量转换,比如,我们可以把在[0, 1] 两个设备上状态为 ``Replicate`` 的分布式张量,转换到 [2, 3] 这两个设备上,并变成状态为 ``Shard`` 的分布式张量。
65+
## 2.3 张量重切分
66+
67+
如果我们希望改变一个分布式张量在集群中的分布式状态,需要使用``重切分`` 功能, 框架中通过``paddle.distributed.reshard``接口提供。
68+
通过重切分我们可以支持跨 ``ProcessMesh`` 的分布式张量转换,比如,我们可以把在[0, 1] 两个设备上状态为 ``Replicate`` 的分布式张量,转换到 [2, 3] 这两个设备上,并变成状态为 ``Shard`` 的分布式张量。
6469

6570
```python
6671
import paddle
@@ -80,11 +85,55 @@ dist_tensor_after_reshard = dist.reshard(dist_tensor, mesh1, placements1)
8085
```
8186
![切分状态](images/reshard.png)
8287

88+
# 三、原理简介
89+
90+
下面我们用一个简单的列子介绍自动并行框架底层的执行流程和原理。
91+
92+
在单卡逻辑视角下我们希望完成计算 C = Matmul(A, B),D = Relu(C)。
93+
假设用户将 TensorB 标记成按列切分,表示在实际分布式集群中 TensorB 被按行切分到不同的 Devices 上。将 TensorA 标记成复制,表示所有 Devices 上都有完整 TensorA 副本。
94+
95+
```python
96+
import paddle
97+
import paddle.distributed as dist
98+
99+
mesh = dist.ProcessMesh([0, 1], dim_names=['x'])
100+
dense_tensorA = paddle.to_tensor([[1,2,], [3,4]])
101+
dense_tensorB = paddle.to_tensor([[5,6], [7,8]])
102+
placementsA = [dist.Replicate()]
103+
placementsB = [dist.Shard(0)]
104+
105+
dist_tensorA = dist.shard_tensor(dense_tensorA, mesh, placementsA)
106+
dist_tensorB = dist.shard_tensor(dense_tensorB, mesh, placementsB)
107+
dist_tensorC = Matmul(dist_tensorA, dist_tensorB)
108+
dist_tensorD = relu(dist_tensorC)
109+
```
110+
<div style="text-align: center;">
111+
<img src="images/underlying1.png" alt="用户标记" style="width: 45%; height: auto; center;">
112+
<!-- ![原理简介](images/underlying1.png) -->
113+
</div>
114+
115+
接下来就会进入自动并行的第一个核心逻辑 **切分推导**
116+
当前用户标记的输入切分状态是无法被 Matmul 算子实际计算的(TensorA 的第 0 维和 TensorB 的第 1 维不匹配)。
117+
这时候自动并行框架会使用当前算子的切分推导规则(e.g. MatmulSPMD Rule),根据输入 tensors 的切分状态,推导出一套合法且性能较优的 输入-输出 张量的切分状态。
118+
在上述输入的切分状态下,框架会推导出会将 TensorA 的切分状态推导成按列切分,TensorB 保持切分状态不变,Matmul 的计算结果 TensorC 的切分状态是 Partial。
119+
因为后续的 Relu 算子是非线性的,输入不能是 Partial 状态,所以框架会根据 ReluSPMD Rule 将 TensorC 输入 Relu 前的的分布式状态推导成 Replicated。
120+
<div style="text-align: center;">
121+
<img src="images/underlying2.png" alt="切分推导" style="width: 45%; height: auto; center;">
122+
</div>
123+
124+
接下来就会进入自动并行的第二个核心逻辑 **切分转换**
125+
框架会根据 tensor 当前的切分状态(src_placement),和切分推导规则推导出的算子计算需要的切分状态(dst_placement),添加对应的通信/张量维度变换算子。
126+
根据上图的切分推导,在计算 Matmul 添加 split 算子,在计算 Relue 添加 Allreduce,将输入 tensor 转换成需要的切分状态进行实际计算。
127+
128+
<div style="text-align: center;">
129+
<img src="images/underlying3.png" alt="切分转换" style="width: 45%; height: auto; center;">
130+
</div>
131+
<!-- ![原理简介](images/underlying3.png) -->
132+
83133

84-
三、自动并行和分布式策略
85-
-------------------
134+
# 四、使用示例
86135

87-
3.1 数据并行
136+
## 4.1 数据并行
88137

89138
数据并行是深度学习领域最常用的并行方法,在此策略下将数据沿 batch 维切分成多份,每个计算资源上保存完整的模型参数并独立处理一份子数据集。用自动并行的语义,用户只需要将输入标记为沿着 batch 维做切分,不需要进行其他额外的操作。
90139

@@ -148,7 +197,7 @@ for step, inputs in enumerate(dataloader):
148197
opt.clear_grad()
149198
```
150199

151-
3.2 张量并行
200+
## 4.2 张量并行
152201

153202
张量并行是在保证数学上正确的前提下,将组网中的参数切分到不同的计算设备,达到降低单个计算设备上的显存消耗的目的。用户需要显式在组网里标记切分参数的方式。
154203

@@ -178,7 +227,7 @@ class MlpModel(paddle.nn.Layer):
178227
return z
179228
```
180229

181-
3.3 流水并行
230+
## 4.3 流水并行
182231

183232
流水并行将模型的不同层放到不同的计算设备上,达到降低单个计算设备的显存消耗的目的。流水并行需要用户显式调用 ``paddle.distributed.reshard``,将前一个流水并行层的计算结果,显式传输到当前流水并行层作为输入。
184233

@@ -207,7 +256,7 @@ class MlpModel(paddle.nn.Layer):
207256
return z
208257
```
209258

210-
3.4 3D 混合并行策略
259+
## 4.4 3D 混合并行策略
211260

212261
下面是一个完整的包含数据并行、张量并行、流水并行三种策略的示例,在 ``ProcessMesh`` 的 0 维上做数据并行,1 维上做张量并行,跨 ``mesh``上做流水并行。
213262

@@ -277,7 +326,7 @@ for step, inputs in enumerate(dataloader):
277326
opt.clear_grad()
278327
```
279328

280-
3.5 动转静机制
329+
## 4.5 动转静训练
281330

282331
动态图和静态图是框架的两种执行模式,动态图方便用户调试和开发,可以即时得到执行结果,静态图会做性能优化和调度编排,将硬件资源用到极致,为了兼备两者的优点,我们提供动转静机制,支持用户在动态图上开发调试后,转成静态图执行。
283332

53 KB
Loading
172 KB
Loading
168 KB
Loading

0 commit comments

Comments
 (0)