Commit 0e74bf3

[NewIR]support c_allreduce_sum/c_identity/c_embedding/c_embedding_grad (#56836)
* [NewIR]add c_allreduce_sum/c_identity/c_reduce_sum/c_embedding/c_embedding_grad
* rm VLOG
* rm c_identity from LegacyOpList
* rm VLOG
* rm c_reduce_sum
1 parent 10d60b7 commit 0e74bf3

File tree: 6 files changed, +106 -1 lines

paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc

Lines changed: 3 additions & 1 deletion

@@ -27,7 +27,9 @@ const std::unordered_set<std::string> LegacyOpList = {
     "pd.c_sync_calc_stream_",
     "pd.c_sync_comm_stream_",
     "pd.send_v2",
-    "pd.recv_v2"};
+    "pd.recv_v2",
+    "pd.c_allreduce_sum",
+    "pd.c_allreduce_sum_"};
 
 enum class AttrType {
   UNDEFINED = 0,
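Since LegacyOpList is a plain std::unordered_set<std::string>, deciding whether an op stays on the legacy translation path is a constant-time lookup. A minimal sketch of that check, with the set contents taken from this diff; the IsLegacyOp helper is illustrative, not Paddle's actual call site:

#include <string>
#include <unordered_set>

// Entries mirror LegacyOpList after this commit.
const std::unordered_set<std::string> LegacyOpList = {
    "pd.c_sync_calc_stream_",
    "pd.c_sync_comm_stream_",
    "pd.send_v2",
    "pd.recv_v2",
    "pd.c_allreduce_sum",
    "pd.c_allreduce_sum_"};

// Hypothetical helper: true if op_name must go through the legacy path.
bool IsLegacyOp(const std::string& op_name) {
  return LegacyOpList.count(op_name) > 0;
}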

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 11 additions & 0 deletions

@@ -102,6 +102,17 @@
   composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics)
   backward : batch_norm_double_grad
 
+- backward_op : c_embedding_grad
+  forward : c_embedding (Tensor weight, Tensor x, int64_t start_index=0) -> Tensor(out)
+  args : (Tensor weight, Tensor x, Tensor out_grad, int64_t start_index=0)
+  output : Tensor(weight_grad)
+  infer_meta :
+    func : EmbeddingGradInferMeta
+    param : [x, weight]
+  kernel :
+    func : c_embedding_grad
+  no_need_buffer : weight
+
 - backward_op : cast_grad
   forward : cast (Tensor x, DataType dtype) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
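For context, c_embedding is the vocabulary-sharded embedding used in model parallelism: each rank holds rows [start_index, start_index + height) of the full table, and its gradient scatters out_grad rows back into that local shard, with ids owned by other ranks contributing nothing. A minimal CPU sketch of that accumulation, assuming float32 data, int64 ids, and row-major layout; this is illustrative only, not the registered c_embedding_grad kernel:

#include <cstdint>
#include <vector>

// Illustrative gradient accumulation for c_embedding_grad: rows of
// out_grad whose id falls inside this rank's shard
// [start_index, start_index + height) are added into weight_grad.
void CEmbeddingGradCpu(const std::vector<int64_t>& ids,
                       const std::vector<float>& out_grad,  // [ids.size(), width]
                       int64_t start_index,
                       int64_t height,
                       int64_t width,
                       std::vector<float>* weight_grad) {   // [height, width]
  weight_grad->assign(height * width, 0.0f);
  for (size_t i = 0; i < ids.size(); ++i) {
    const int64_t local = ids[i] - start_index;
    if (local < 0 || local >= height) continue;  // id owned by another rank
    for (int64_t j = 0; j < width; ++j) {
      (*weight_grad)[local * width + j] += out_grad[i * width + j];
    }
  }
}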

paddle/phi/api/yaml/legacy_ops.yaml

Lines changed: 31 additions & 0 deletions

@@ -123,6 +123,16 @@
   backward : batch_norm_grad
   optional : reserve_space
 
+- op : c_allreduce_sum
+  args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
+  output : Tensor(out)
+  infer_meta :
+    func : AllReduceInferMeta
+    param : [x]
+  kernel :
+    func : c_allreduce_sum
+  inplace : (x -> out)
+
 - op : c_broadcast
   args : (Tensor x, int ring_id=0, int root=0, bool use_calc_stream=false)
   output : Tensor(out)

@@ -142,6 +152,27 @@
   kernel :
     func : c_concat
 
+- op : c_embedding
+  args : (Tensor weight, Tensor x, int64_t start_index=0)
+  output : Tensor(out)
+  infer_meta :
+    func : CEmbeddingInferMeta
+    param : [weight, x, start_index]
+  kernel :
+    func : c_embedding
+    param : [weight, x, start_index]
+    data_type : weight
+  backward : c_embedding_grad
+
+- op : c_identity
+  args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
+  output : Tensor(out)
+  infer_meta :
+    func : CIdentityInferMeta
+  kernel :
+    func : c_identity
+  inplace : (x -> out)
+
 - op : c_sync_calc_stream
   args : (Tensor x)
   output : Tensor(out)
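Both c_allreduce_sum and c_identity leave shape and dtype unchanged, which is why their infer_meta entries need nothing beyond x, and the inplace : (x -> out) lines let the generated inplace variants (e.g. c_allreduce_sum_) reuse the input buffer. Such a passthrough meta function presumably reduces to copying the input's metadata; a sketch under that assumption, not the actual bodies of AllReduceInferMeta or CIdentityInferMeta:

#include "paddle/phi/core/meta_tensor.h"

// Illustrative passthrough: the output mirrors the input's metadata.
void PassThroughInferMeta(const phi::MetaTensor& x, phi::MetaTensor* out) {
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
}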

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 19 additions & 0 deletions

@@ -435,6 +435,13 @@
   outputs :
     out : Out
 
+- op : c_embedding
+  backward : c_embedding_grad
+  inputs :
+    {weight : W, x : Ids}
+  outputs :
+    out : Out
+
 - op : cast
   inputs :
     x : X

@@ -3032,12 +3039,24 @@
   yolo_loss : GetYoloLossExpectedKernelType
   yolo_loss_grad : GetYoloLossExpectedKernelType
 
+- op: c_allreduce_sum
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_broadcast
   inputs :
     x : X
   outputs :
     out : Out
 
+- op: c_identity
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_sync_calc_stream
   inputs :
     x : X

paddle/phi/infermeta/binary.cc

Lines changed: 37 additions & 0 deletions

@@ -1274,6 +1274,43 @@ void EmbeddingInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void CEmbeddingInferMeta(const MetaTensor& weight,
+                         const MetaTensor& x,
+                         int64_t start_index,
+                         MetaTensor* out) {
+  const auto& table_dims = weight.dims();
+  const auto& ids_dims = x.dims();
+  int ids_rank = ids_dims.size();
+
+  VLOG(5) << "ids rank is " << ids_rank << std::endl;
+  PADDLE_ENFORCE_EQ(
+      table_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "ShapeError: The dimensions of the 'c_embedding' must be 2. "
+          "But received c_embedding's dimensions = %d, "
+          "c_embedding's shape = [%s].",
+          table_dims.size(),
+          table_dims));
+
+  auto output_dims = phi::vectorize(ids_dims);
+  output_dims.push_back(table_dims[1]);
+  out->set_dims(phi::make_ddim(output_dims));
+  out->set_dtype(weight.dtype());
+  out->share_lod(x);
+
+  const auto height = table_dims[0];
+  const auto width = table_dims[1];
+  PADDLE_ENFORCE_EQ(
+      (height > 0 && width > 0 && start_index >= 0),
+      true,
+      phi::errors::InvalidArgument(
+          "height:%ld width:%ld start_index:%ld must not have negative values",
+          height,
+          width,
+          start_index));
+}
+
 void ExpandAsInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
                        const std::vector<int>& target_shape,
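The shape rule in CEmbeddingInferMeta is simply: output dims = ids dims with the table width appended, so ids of shape [8, 128] looked up in a [vocab_shard, 1024] table produce a [8, 128, 1024] output. A standalone sketch of just that computation, with plain std::vector standing in for phi::DDim:

#include <cstdint>
#include <vector>

// Output shape of c_embedding: ids shape + [table width].
std::vector<int64_t> CEmbeddingOutDims(const std::vector<int64_t>& ids_dims,
                                       const std::vector<int64_t>& table_dims) {
  std::vector<int64_t> out = ids_dims;  // e.g. {8, 128}
  out.push_back(table_dims[1]);         // append hidden size, e.g. 1024
  return out;                           // {8, 128, 1024}
}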

paddle/phi/infermeta/binary.h

Lines changed: 5 additions & 0 deletions

@@ -211,6 +211,11 @@ void EmbeddingInferMeta(const MetaTensor& x,
                         int64_t padding_idx,
                         MetaTensor* out);
 
+void CEmbeddingInferMeta(const MetaTensor& weight,
+                         const MetaTensor& x,
+                         int64_t start_index,
+                         MetaTensor* out);
+
 void ExpandAsInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
                        const std::vector<int>& target_shape,
