Skip to content

Commit e6559ce

Browse files
authored
fix some bugs (#7320)
1 parent 7bbfc1a commit e6559ce

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

docs/dev_guides/api_contributing_guides/auto_parallel_op_cn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
<img src="https://raw.githubusercontent.com/PaddlePaddle/docs/develop/docs/dev_guides/api_contributing_guides/images/process_mesh_2-2.png" width="50%"/>
2020
</p>
2121

22-
我们可以使用 DimsMapping 来表示数据在 ProcessMesh 上的分布方式,DimsMapping[i] = j 表示张量的第 i 维在 ProcessMesh 的第 j 维上被切分,若 j 为 -1 则表示不切分,在该维上复制。例如张量大小为 (4, 4),process_mesh 的大小为 [2, 2],DimsMapping = [-1, 1] 表示张量的第 0 维不切分,第 1 维在 ProcessMesh 的第 1 维上切分,切分后每个卡上的张量大小为 (2, 1),等价于 Placements 表示的 [Replicate(), Shard(1)]。DimsMapping = [0, 1] 表示张量的第 0 维在 ProcessMesh 的第 0 维上切分,第 1 维在 ProcessMesh 的第 1 维上切分,切分后每个卡上的张量大小为 (1, 1),等价于 Placements 表示的 [Shard(0), Shard(1)]。下图分别展示了 DimsMapping 为 [-1, 1][0, 1] 时的张量切分情况。
22+
我们可以使用 DimsMapping 来表示数据在 ProcessMesh 上的分布方式,DimsMapping[i] = j 表示张量的第 i 维在 ProcessMesh 的第 j 维上被切分,若 j 为 -1 则表示不切分,在该维上复制。例如张量大小为 (2, 2),process_mesh 的大小为 [2, 2],DimsMapping = [-1, 1] 表示张量的第 0 维不切分,第 1 维在 ProcessMesh 的第 1 维上切分,切分后每个卡上的张量大小为 (2, 1),等价于 Placements 表示的 [Replicate(), Shard(1)]。DimsMapping = [0, 1] 表示张量的第 0 维在 ProcessMesh 的第 0 维上切分,第 1 维在 ProcessMesh 的第 1 维上切分,切分后每个卡上的张量大小为 (1, 1),等价于 Placements 表示的 [Shard(0), Shard(1)]。下图分别展示了 DimsMapping 为 [-1, 1][0, 1] 时的张量切分情况。
2323

2424
<p align="center">
2525
<img src="https://raw.githubusercontent.com/PaddlePaddle/docs/develop/docs/dev_guides/api_contributing_guides/images/DimMapping.png" width="70%"/>

docs/guides/paddle_v3_features/auto_parallel_cn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class MlpModel(paddle.nn.Layer):
182182
self.w1 = self.create_parameter(shape=[4096, 1024])
183183

184184
def forward(self, x):
185-
dist.shard_tensor(x, mesh, [dist.Shard(0)]) # 标记输入数据沿第 0 维切分
185+
x = dist.shard_tensor(x, mesh, [dist.Shard(0)]) # 标记输入数据沿第 0 维切分
186186
y = paddle.matmul(x, self.w0)
187187
z = paddle.matmul(y, self.w1)
188188
return z

0 commit comments

Comments
 (0)