@@ -9,43 +9,43 @@ UtilBase
99方法
1010::::::::::::
1111all_reduce(input, mode="sum", comm_world="worker")
12- '''''''''
12+ ''''''''''''''''''''''''''''''''''''''''''''''''''
1313在指定的通信集合间进行归约操作,并将归约结果返回给集合中每个实例。
1414
1515**参数 **
1616
1717 - **input ** (list|tuple|numpy.array) – 归约操作的输入。
1818 - **mode ** (str) - 归约操作的模式,包含求和,取最大值和取最小值,默认为求和归约。
19- - **comm_world ** (str) - 归约操作的通信集合,包含:server 集合(``server ``),worker 集合(``worker ``)及所有节点集合(``all ``),默认为 worker 集合。
19+ - **comm_world ** (str) - 归约操作的通信集合,包含:server 集合 (``server ``),worker 集合 (``worker ``) 及所有节点集合 (``all ``),默认为 worker 集合。
2020
2121**返回 **
2222
23- Numpy.array|None:一个和``input``形状一致的 numpy 数组或 None。
23+ Numpy.array|None:一个和 ``input `` 形状一致的 numpy 数组或 None。
2424
2525**代码示例 **
2626
2727COPY-FROM: paddle.distributed.fleet.UtilBase.all_reduce
2828
2929barrier(comm_world="worker")
30- '''''''''
30+ ''''''''''''''''''''''''''''
3131在指定的通信集合间进行阻塞操作,以实现集合间进度同步。
3232
3333**参数 **
3434
35- - **comm_world ** (str) - 阻塞操作的通信集合,包含:server 集合(``server ``),worker 集合(``worker ``)及所有节点集合(``all ``),默认为 worker 集合。
35+ - **comm_world ** (str) - 阻塞操作的通信集合,包含:server 集合 (``server ``),worker 集合 (``worker ``) 及所有节点集合 (``all ``),默认为 worker 集合。
3636
3737**代码示例 **
3838
3939COPY-FROM: paddle.distributed.fleet.UtilBase.barrier
4040
4141all_gather(input, comm_world="worker")
42- '''''''''
42+ ''''''''''''''''''''''''''''''''''''''''
4343在指定的通信集合间进行聚合操作,并将聚合的结果返回给集合中每个实例。
4444
4545**参数 **
4646
4747 - **input ** (int|float) - 聚合操作的输入。
48- - **comm_world ** (str) - 聚合操作的通信集合,包含:server 集合(``server ``),worker 集合(``worker ``)及所有节点集合(``all ``),默认为 worker 集合。
48+ - **comm_world ** (str) - 聚合操作的通信集合,包含:server 集合 (``server ``),worker 集合 (``worker ``) 及所有节点集合 (``all ``),默认为 worker 集合。
4949
5050**返回 **
5151
@@ -56,7 +56,7 @@ all_gather(input, comm_world="worker")
5656COPY-FROM: paddle.distributed.fleet.UtilBase.all_gather
5757
5858get_file_shard(files)
59- '''''''''
59+ '''''''''''''''''''''
6060在数据并行的分布式训练中,获取属于当前训练节点的文件列表。
6161
6262.. code-block :: text
@@ -77,8 +77,7 @@ get_file_shard(files)
7777COPY-FROM: paddle.distributed.fleet.UtilBase.get_file_shard
7878
7979print_on_rank(message, rank_id)
80- '''''''''
81-
80+ '''''''''''''''''''''''''''''''''
8281在编号为 `rank_id ` 的节点上打印指定信息。
8382
8483**参数 **
0 commit comments