
Commit 201480d

fix bug in c_split (#56917)
1 parent 79bfb18 commit 201480d

File tree

1 file changed: +33 -13 lines changed

  • python/paddle/distributed/fleet/layers/mpu


python/paddle/distributed/fleet/layers/mpu/mp_ops.py

Lines changed: 33 additions & 13 deletions
@@ -48,6 +48,38 @@ def backward(ctx, dy):
         return dy


+class c_split_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, rank, nranks):
+        ctx.group = group
+        ctx.nranks = nranks
+        return _legacy_C_ops.c_split(
+            tensor,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            group.id,
+            'rank',
+            rank,
+            'nranks',
+            nranks,
+            'use_model_parallel',
+            True,
+        )
+
+    @staticmethod
+    def backward(ctx, dy):
+        group = ctx.group
+        out_shape = dy.shape
+        out_shape[0] = out_shape[0] * ctx.nranks
+        out = paddle.empty(out_shape, dtype=dy.dtype)
+        group.process_group.all_gather_into_tensor_on_calc_stream(
+            out,
+            dy,
+        )
+        return out
+
+
 def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.
@@ -179,19 +211,7 @@ def _c_split(tensor, group=None):
     )

     if in_dynamic_mode():
-        return _legacy_C_ops.c_split(
-            tensor,
-            'use_calc_stream',
-            True,
-            'ring_id',
-            ring_id,
-            'rank',
-            rank,
-            'nranks',
-            nranks,
-            'use_model_parallel',
-            True,
-        )
+        return c_split_eager.apply(tensor, group, rank, nranks)
     else:
         op_type = 'c_split'
         helper = LayerHelper(op_type, **locals())
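
The new class follows Paddle's custom-autograd pattern: a PyLayer subclass whose forward runs the raw c_split kernel and whose backward reverses the split with an all-gather across the model-parallel group. As a minimal, single-process sketch of that pattern (using tanh as a stand-in computation instead of the distributed c_split kernel, and with a hypothetical class name CusTanh), it could look like this:

import paddle
from paddle.autograd import PyLayer

# Illustrative only: a PyLayer with a hand-written forward/backward pair,
# mirroring the structure of c_split_eager above. The real class wraps
# _legacy_C_ops.c_split and an all_gather, which need a distributed setup.
class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)  # stash state needed by backward
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensor()
        return dy * (1 - paddle.square(y))  # d/dx tanh(x) = 1 - tanh(x)^2

x = paddle.randn([2, 3])
x.stop_gradient = False
out = CusTanh.apply(x)  # same call style as c_split_eager.apply(...)
out.mean().backward()
print(x.grad.shape)  # [2, 3]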
