Skip to content

Commit e723351

Browse files
authored
[Arith][MemRef] add AtomicRMWKind::xori to enum (#151701)
Add missing xor AtomicRMWKind enum in arith. Also add support for xor to memref.atomic_rmw so the change can be tested. This does NOT add it for all users of the enum (e.g. Affine, Vector)
1 parent 9573124 commit e723351

File tree

5 files changed

+29
-19
lines changed

5 files changed

+29
-19
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
7676

7777
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
7878
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
79-
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
80-
def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
81-
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
82-
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
83-
def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
84-
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
85-
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
86-
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
87-
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
88-
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
89-
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
90-
def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
91-
def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
79+
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 2>;
80+
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 3>;
81+
def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
82+
def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 5>;
83+
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 6>;
84+
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 7>;
85+
def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
86+
def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 9>;
87+
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 10>;
88+
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 11>;
89+
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 12>;
90+
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 13>;
91+
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 14>;
92+
def ATOMIC_RMW_KIND_XORI : I64EnumAttrCase<"xori", 15>;
9293

9394
def AtomicRMWKindAttr : I64EnumAttr<
9495
"AtomicRMWKind", "",
95-
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
96-
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
97-
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
96+
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI,
97+
ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
98+
ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF,
99+
ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
98100
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
99-
ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
101+
ATOMIC_RMW_KIND_XORI]> {
100102
let cppNamespace = "::mlir::arith";
101103
}
102104

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
18721872
return LLVM::AtomicBinOp::umin;
18731873
case arith::AtomicRMWKind::ori:
18741874
return LLVM::AtomicBinOp::_or;
1875+
case arith::AtomicRMWKind::xori:
1876+
return LLVM::AtomicBinOp::_xor;
18751877
case arith::AtomicRMWKind::andi:
18761878
return LLVM::AtomicBinOp::_and;
18771879
default:

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
26782678
case AtomicRMWKind::addi:
26792679
case AtomicRMWKind::maxu:
26802680
case AtomicRMWKind::ori:
2681+
case AtomicRMWKind::xori:
26812682
return builder.getZeroAttr(resultType);
26822683
case AtomicRMWKind::andi:
26832684
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
27362737
// Integer operations.
27372738
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
27382739
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2739-
.Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2740+
.Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
27402741
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
27412742
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
27422743
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
28062807
return arith::OrIOp::create(builder, loc, lhs, rhs);
28072808
case AtomicRMWKind::andi:
28082809
return arith::AndIOp::create(builder, loc, lhs, rhs);
2810+
case AtomicRMWKind::xori:
2811+
return arith::XOrIOp::create(builder, loc, lhs, rhs);
28092812
// TODO: Add remaining reduction operations.
28102813
default:
28112814
(void)emitOptionalError(loc, "Reduction operation type not supported");

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
35583558
case arith::AtomicRMWKind::minu:
35593559
case arith::AtomicRMWKind::muli:
35603560
case arith::AtomicRMWKind::ori:
3561+
case arith::AtomicRMWKind::xori:
35613562
case arith::AtomicRMWKind::andi:
35623563
if (!llvm::isa<IntegerType>(getValue().getType()))
35633564
return emitOpError() << "with kind '"

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
464464
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
465465
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
466466
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
467-
// CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
467+
memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
468+
// CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
469+
// CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
468470
return
469471
}
470472

0 commit comments

Comments
 (0)