Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions docs/api/paddle/distributed/utils/global_scatter_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ global_scatter
global_scatter 根据 local_count 将 x 的数据分发到 n_expert * world_size 个 expert,然后根据 global_count 接收数据。
其中 expert 是用户定义的专家网络,n_expert 是指每张卡拥有的专家网络数目,world_size 是指运行网络的显卡数目。

如下图所示,world_size 是 2,n_expert 是 2,x 的 batch_size 是 4,local_count 是[2, 0, 2, 0],0 卡的 global_count 是[2, 0, , ],
1 卡的 global_count 是[2, 0, ,](因为篇幅问题,这里只展示在 0 卡运算的数据),在 global_scatter 算子里,
local_count[i]代表向第 (i // n_expert)张卡的第 (i % n_expert)个 expert 发送 local_expert[i]个数据,
global_count[i]代表从第 (i // n_expert)张卡接收 global_count[i]个数据给本卡的 第(i % n_expert)个 expert。
如下图所示,world_size 是 2,n_expert 是 2,x 的 batch_size 是 4,local_count 是[2, 0, 2, 0],0 卡的 global_count 是 [2, 0, , ],
1 卡的 global_count 是 [2, 0, ,](因为篇幅问题,这里只展示在 0 卡运算的数据),在 global_scatter 算子里,
local_count[i] 代表向第 (i // n_expert) 张卡的第 (i % n_expert) 个 expert 发送 local_expert[i] 个数据,
global_count[i] 代表从第 (i // n_expert) 张卡接收 global_count[i] 个数据给本卡的 第(i % n_expert)个 expert。
图中的 rank0 代表第 0 张卡,rank1 代表第 1 张卡。
global_scatter 发送数据的流程如下:

local_count[0]代表从 x 里取出 2 个 batch 的数据向第 0 张卡的第 0 个 expert 发送 2 个数据;
local_count[0] 代表从 x 里取出 2 个 batch 的数据向第 0 张卡的第 0 个 expert 发送 2 个数据;

local_count[1]代表从 x 里取出 0 个 batch 的数据向第 0 张卡的第 1 个 expert 发送 0 个数据;
local_count[1] 代表从 x 里取出 0 个 batch 的数据向第 0 张卡的第 1 个 expert 发送 0 个数据;

local_count[2]代表从 x 里取出 2 个 batch 的数据向第 1 张卡的第 0 个 expert 发送 2 个数据;
local_count[2] 代表从 x 里取出 2 个 batch 的数据向第 1 张卡的第 0 个 expert 发送 2 个数据;

local_count[3]代表从 x 里取出 0 个 batch 的数据向第 1 张卡的第 1 个 expert 发送 0 个数据;
local_count[3] 代表从 x 里取出 0 个 batch 的数据向第 1 张卡的第 1 个 expert 发送 0 个数据;

所以第 0 张卡的 global_count[0]等于 2,代表从第 0 张卡接收 2 个 batch 的数据给第 0 个 expert;
所以第 0 张卡的 global_count[0] 等于 2,代表从第 0 张卡接收 2 个 batch 的数据给第 0 个 expert;

第 0 张卡的 global_count[1]等于 0,代表从第 0 张卡接收 0 个 batch 的数据给第 1 个 expert;
第 0 张卡的 global_count[1] 等于 0,代表从第 0 张卡接收 0 个 batch 的数据给第 1 个 expert;

第 1 张卡的 global_count[0]等于 2,代表从第 0 张卡接收 2 个 batch 的数据给第 0 个 expert;
第 1 张卡的 global_count[0] 等于 2,代表从第 0 张卡接收 2 个 batch 的数据给第 0 个 expert;

第 1 张卡的 global_count[1]等与 0,代表从第 0 张卡接收 0 个 batch 的数据给第 1 个 expert。
第 1 张卡的 global_count[1] 等与 0,代表从第 0 张卡接收 0 个 batch 的数据给第 1 个 expert。


.. image:: ../img/global_scatter_gather.png
Expand All @@ -52,4 +52,40 @@ Tensor,从所有 expert 接收的数据,按照每个 expert 排列。

代码示例
:::::::::
COPY-FROM: paddle.distributed.utils.global_scatter

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不是用code-block来写代码,是COPY-FROM从英文api里面同步代码,可以学习一下其他PR之后进行修改

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明白

.. code-block:: python

# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env

# 初始化并行环境
init_parallel_env()
n_expert = 2
world_size = 2
d_model = 2
in_feat = d_model
local_input_buf = np.array([[1, 2],[3, 4],[5, 6],[7, 8],[9, 10]], dtype=np.float32)
if paddle.distributed.ParallelEnv().local_rank == 0:
local_count = np.array([2, 1, 1, 1])
global_count = np.array([2, 1, 1, 1])
else:
local_count = np.array([1, 1, 2, 1])
global_count = np.array([1, 1, 2, 1])
local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False)
local_count = paddle.to_tensor(local_count, dtype="int64")
global_count = paddle.to_tensor(global_count, dtype="int64")
a = paddle.distributed.utils.global_scatter(local_input_buf, local_count, global_count)
a.stop_gradient = False
print(a)
# rank 0 输出: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]]
# rank 1 输出: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]]
# backward test
c = a * a
c.backward()
print("local_input_buf.grad: ", local_input_buf.grad)
# rank 0 输出: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# rank 1 输出: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# 喵喵喵AwA