Skip to content

Commit 45b837f

Browse files
committed
OpClass -> ReductionOp
1 parent 15cb34f commit 45b837f

File tree

4 files changed

+42
-42
lines changed

4 files changed

+42
-42
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPI.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
230230
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
231231
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
232232

233-
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
233+
def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [
234234
MPI_OpNull,
235235
MPI_OpMax,
236236
MPI_OpMin,

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
283283
let arguments = (
284284
ins AnyMemRef : $sendbuf,
285285
AnyMemRef : $recvbuf,
286-
MPI_OpClassEnum : $op,
286+
MPI_ReductionOpEnum : $op,
287287
MPI_Comm : $comm
288288
);
289289

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class MPIImplTraits {
116116
/// enum value.
117117
virtual Value getMPIOp(const Location loc,
118118
ConversionPatternRewriter &rewriter,
119-
mpi::MPI_OpClassEnum opAttr) = 0;
119+
mpi::MPI_ReductionOpEnum opAttr) = 0;
120120
};
121121

122122
//===----------------------------------------------------------------------===//
@@ -199,49 +199,49 @@ class MPICHImplTraits : public MPIImplTraits {
199199
}
200200

201201
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
202-
mpi::MPI_OpClassEnum opAttr) override {
202+
mpi::MPI_ReductionOpEnum opAttr) override {
203203
int32_t op = MPI_NO_OP;
204204
switch (opAttr) {
205-
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
205+
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
206206
op = MPI_NO_OP;
207207
break;
208-
case mpi::MPI_OpClassEnum::MPI_MAX:
208+
case mpi::MPI_ReductionOpEnum::MPI_MAX:
209209
op = MPI_MAX;
210210
break;
211-
case mpi::MPI_OpClassEnum::MPI_MIN:
211+
case mpi::MPI_ReductionOpEnum::MPI_MIN:
212212
op = MPI_MIN;
213213
break;
214-
case mpi::MPI_OpClassEnum::MPI_SUM:
214+
case mpi::MPI_ReductionOpEnum::MPI_SUM:
215215
op = MPI_SUM;
216216
break;
217-
case mpi::MPI_OpClassEnum::MPI_PROD:
217+
case mpi::MPI_ReductionOpEnum::MPI_PROD:
218218
op = MPI_PROD;
219219
break;
220-
case mpi::MPI_OpClassEnum::MPI_LAND:
220+
case mpi::MPI_ReductionOpEnum::MPI_LAND:
221221
op = MPI_LAND;
222222
break;
223-
case mpi::MPI_OpClassEnum::MPI_BAND:
223+
case mpi::MPI_ReductionOpEnum::MPI_BAND:
224224
op = MPI_BAND;
225225
break;
226-
case mpi::MPI_OpClassEnum::MPI_LOR:
226+
case mpi::MPI_ReductionOpEnum::MPI_LOR:
227227
op = MPI_LOR;
228228
break;
229-
case mpi::MPI_OpClassEnum::MPI_BOR:
229+
case mpi::MPI_ReductionOpEnum::MPI_BOR:
230230
op = MPI_BOR;
231231
break;
232-
case mpi::MPI_OpClassEnum::MPI_LXOR:
232+
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
233233
op = MPI_LXOR;
234234
break;
235-
case mpi::MPI_OpClassEnum::MPI_BXOR:
235+
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
236236
op = MPI_BXOR;
237237
break;
238-
case mpi::MPI_OpClassEnum::MPI_MINLOC:
238+
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
239239
op = MPI_MINLOC;
240240
break;
241-
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
241+
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
242242
op = MPI_MAXLOC;
243243
break;
244-
case mpi::MPI_OpClassEnum::MPI_REPLACE:
244+
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
245245
op = MPI_REPLACE;
246246
break;
247247
}
@@ -336,49 +336,49 @@ class OMPIImplTraits : public MPIImplTraits {
336336
}
337337

338338
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
339-
mpi::MPI_OpClassEnum opAttr) override {
339+
mpi::MPI_ReductionOpEnum opAttr) override {
340340
StringRef op;
341341
switch (opAttr) {
342-
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
342+
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
343343
op = "ompi_mpi_no_op";
344344
break;
345-
case mpi::MPI_OpClassEnum::MPI_MAX:
345+
case mpi::MPI_ReductionOpEnum::MPI_MAX:
346346
op = "ompi_mpi_max";
347347
break;
348-
case mpi::MPI_OpClassEnum::MPI_MIN:
348+
case mpi::MPI_ReductionOpEnum::MPI_MIN:
349349
op = "ompi_mpi_min";
350350
break;
351-
case mpi::MPI_OpClassEnum::MPI_SUM:
351+
case mpi::MPI_ReductionOpEnum::MPI_SUM:
352352
op = "ompi_mpi_sum";
353353
break;
354-
case mpi::MPI_OpClassEnum::MPI_PROD:
354+
case mpi::MPI_ReductionOpEnum::MPI_PROD:
355355
op = "ompi_mpi_prod";
356356
break;
357-
case mpi::MPI_OpClassEnum::MPI_LAND:
357+
case mpi::MPI_ReductionOpEnum::MPI_LAND:
358358
op = "ompi_mpi_land";
359359
break;
360-
case mpi::MPI_OpClassEnum::MPI_BAND:
360+
case mpi::MPI_ReductionOpEnum::MPI_BAND:
361361
op = "ompi_mpi_band";
362362
break;
363-
case mpi::MPI_OpClassEnum::MPI_LOR:
363+
case mpi::MPI_ReductionOpEnum::MPI_LOR:
364364
op = "ompi_mpi_lor";
365365
break;
366-
case mpi::MPI_OpClassEnum::MPI_BOR:
366+
case mpi::MPI_ReductionOpEnum::MPI_BOR:
367367
op = "ompi_mpi_bor";
368368
break;
369-
case mpi::MPI_OpClassEnum::MPI_LXOR:
369+
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
370370
op = "ompi_mpi_lxor";
371371
break;
372-
case mpi::MPI_OpClassEnum::MPI_BXOR:
372+
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
373373
op = "ompi_mpi_bxor";
374374
break;
375-
case mpi::MPI_OpClassEnum::MPI_MINLOC:
375+
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
376376
op = "ompi_mpi_minloc";
377377
break;
378-
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
378+
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
379379
op = "ompi_mpi_maxloc";
380380
break;
381-
case mpi::MPI_OpClassEnum::MPI_REPLACE:
381+
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
382382
op = "ompi_mpi_replace";
383383
break;
384384
}

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -519,23 +519,23 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
519519
}
520520
};
521521

522-
static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
522+
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
523523
auto ctx = kind.getContext();
524524
switch (kind.getValue()) {
525525
case ReductionKind::Sum:
526-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
526+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
527527
case ReductionKind::Product:
528-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
528+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
529529
case ReductionKind::Min:
530-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
530+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
531531
case ReductionKind::Max:
532-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
532+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
533533
case ReductionKind::BitwiseAnd:
534-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
534+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
535535
case ReductionKind::BitwiseOr:
536-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
536+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
537537
case ReductionKind::BitwiseXor:
538-
return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
538+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
539539
default:
540540
assert(false && "Unknown/unsupported reduction kind");
541541
}
@@ -626,7 +626,7 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
626626
// Create the MPI AllReduce operation.
627627
iBuilder.create<mpi::AllReduceOp>(
628628
TypeRange(), buffer1d, buffer1d,
629-
getMPIReduction(adaptor.getReductionAttr()), comm);
629+
getMPIReductionOp(adaptor.getReductionAttr()), comm);
630630

631631
// If the destination is a memref, cast it to a tensor
632632
if (isa<RankedTensorType>(op.getType()))

0 commit comments

Comments
 (0)