@@ -77,11 +77,11 @@ ShardableAxesSignature CreateDefaultSignature(pir::Operation* op) {
77
77
ShardableAxesSignature result = ShardableAxesSignature ();
78
78
for (int i = 0 ; i < op->num_operands (); ++i) {
79
79
result.inputs .emplace_back (
80
- CreateNewNamesWithRank (GetCompitableRank (op->operand_source (i))));
80
+ CreateNewNamesWithRank (GetCompatibleRank (op->operand_source (i))));
81
81
}
82
82
for (int i = 0 ; i < op->num_results (); ++i) {
83
83
result.outputs .emplace_back (
84
- CreateNewNamesWithRank (GetCompitableRank (op->result (i))));
84
+ CreateNewNamesWithRank (GetCompatibleRank (op->result (i))));
85
85
}
86
86
return result;
87
87
}
@@ -109,7 +109,7 @@ ShardableAxesSignature CreateSignatureForReduce(pir::Operation* reduce_op) {
109
109
1 ,
110
110
::common::errors::PreconditionNotMet (
111
111
" Required reduce_op->num_results() shall be equal 1." ));
112
- const size_t input_rank = GetCompitableRank (reduce_op->operand_source (0 ));
112
+ const size_t input_rank = GetCompatibleRank (reduce_op->operand_source (0 ));
113
113
auto input_axes = CreateNewNamesWithRank (input_rank);
114
114
115
115
const std::vector<int64_t > reduce_axis_idx = GetReduceAxisIdx (reduce_op);
@@ -152,20 +152,20 @@ ShardableAxesSignature CreateSignatureForReduce(pir::Operation* reduce_op) {
152
152
ShardableAxesSignature CreateSignatureForElementWise (pir::Operation* op) {
153
153
ShardableAxesSignature result = ShardableAxesSignature ();
154
154
155
- int64_t rank = GetCompitableRank (op->result (0 ));
155
+ int64_t rank = GetCompatibleRank (op->result (0 ));
156
156
auto same_axes = CreateNewNamesWithRank (rank);
157
157
158
158
for (int i = 0 ; i < op->num_operands (); ++i) {
159
159
PADDLE_ENFORCE_EQ (rank,
160
- GetCompitableRank (op->operand_source (i)),
160
+ GetCompatibleRank (op->operand_source (i)),
161
161
::common::errors::PreconditionNotMet (
162
162
" Required all inputs rank shall be equal output in "
163
163
" elementwise op." ));
164
164
result.inputs .emplace_back (same_axes);
165
165
}
166
166
for (int i = 0 ; i < op->num_results (); ++i) {
167
167
PADDLE_ENFORCE_EQ (rank,
168
- GetCompitableRank (op->result (i)),
168
+ GetCompatibleRank (op->result (i)),
169
169
::common::errors::PreconditionNotMet (
170
170
" Required all outputs rank shall be equal each other "
171
171
" in elementwise op." ));
@@ -188,7 +188,7 @@ ShardableAxesSignature CreateSignatureForTranspose(pir::Operation* op) {
188
188
" Required transpose_op->num_results() shall be equal 1." ));
189
189
190
190
const auto input_axes =
191
- CreateNewNamesWithRank (GetCompitableRank (op->operand_source (0 )));
191
+ CreateNewNamesWithRank (GetCompatibleRank (op->operand_source (0 )));
192
192
193
193
std::vector<int32_t > perm =
194
194
GetInt32ArrayAttributeData (op->attributes ().at (" perm" ));
@@ -224,7 +224,7 @@ ShardableAxesSignature CreateSignatureForSlice(
224
224
" Required slice_op->num_results() shall be equal 1." ));
225
225
226
226
const auto input_axes =
227
- CreateNewNamesWithRank (GetCompitableRank (op->operand_source (0 )));
227
+ CreateNewNamesWithRank (GetCompatibleRank (op->operand_source (0 )));
228
228
229
229
const auto [slice_axis, keepdim] = GetSliceAxis (op);
230
230
const auto output_axes = [&]() -> decltype (auto ) {
@@ -266,8 +266,8 @@ ShardableAxesSignature CreateSignatureForBroadcast(
266
266
" Required broad_cast_value is not empty." ));
267
267
268
268
const auto & [input_value, output_value] = broad_cast_value.value ();
269
- const int input_rank = GetCompitableRank (input_value);
270
- const int output_rank = GetCompitableRank (output_value);
269
+ const int input_rank = GetCompatibleRank (input_value);
270
+ const int output_rank = GetCompatibleRank (output_value);
271
271
PADDLE_ENFORCE_GE (
272
272
output_rank,
273
273
input_rank,
@@ -278,7 +278,7 @@ ShardableAxesSignature CreateSignatureForBroadcast(
278
278
// output.
279
279
for (int i = 0 ; i < op->num_operands (); ++i) {
280
280
result.inputs .emplace_back (
281
- CreateNewNamesWithRank (GetCompitableRank (op->operand_source (i))));
281
+ CreateNewNamesWithRank (GetCompatibleRank (op->operand_source (i))));
282
282
}
283
283
284
284
// Create output axes. Compare axis one by one, from back to front.
@@ -309,8 +309,8 @@ ShardableAxesSignature CreateSignatureForReshape(
309
309
pir::ShapeConstraintIRAnalysis* shape_analysis) {
310
310
const auto input_value = op->operand_source (0 );
311
311
const auto output_value = op->result (0 );
312
- const auto input_rank = GetCompitableRank (op->operand_source (0 ));
313
- const auto output_rank = GetCompitableRank (op->result (0 ));
312
+ const auto input_rank = GetCompatibleRank (op->operand_source (0 ));
313
+ const auto output_rank = GetCompatibleRank (op->result (0 ));
314
314
const auto in_shape = GetDimExprsFromValue (input_value);
315
315
const auto out_shape = GetDimExprsFromValue (output_value);
316
316
@@ -320,7 +320,7 @@ ShardableAxesSignature CreateSignatureForReshape(
320
320
321
321
if (op->name () == " pd_op.reshape" && op->num_operands () == 2 ) {
322
322
result.inputs .emplace_back (
323
- CreateNewNamesWithRank (GetCompitableRank (op->operand_source (1 ))));
323
+ CreateNewNamesWithRank (GetCompatibleRank (op->operand_source (1 ))));
324
324
}
325
325
326
326
if (GetRank (input_value) == 0 || GetRank (output_value) == 0 ) {
@@ -387,7 +387,7 @@ ShardableAxesSignature CreateSignatureForReshape(
387
387
388
388
ShardableAxesSignature CreateSignatureForConcat (
389
389
pir::Operation* op, ShardableAxesInfoManager* axes_manager) {
390
- size_t rank = GetCompitableRank (op->result (0 ));
390
+ size_t rank = GetCompatibleRank (op->result (0 ));
391
391
const auto same_axes = CreateNewNamesWithRank (rank - 1 );
392
392
393
393
const auto axis_attr =
@@ -406,7 +406,7 @@ ShardableAxesSignature CreateSignatureForConcat(
406
406
ShardableAxesSignature result = ShardableAxesSignature ();
407
407
for (int i = 0 ; i < op->num_operands (); ++i) {
408
408
PADDLE_ENFORCE_EQ (rank,
409
- GetCompitableRank (op->operand_source (i)),
409
+ GetCompatibleRank (op->operand_source (i)),
410
410
::common::errors::PreconditionNotMet (
411
411
" Required all inputs rank shall be equal output in "
412
412
" concat op." ));
0 commit comments