Skip to content

Commit 3315fb2

Browse files
Muzammiluddin-Syed-ECEhhkit
authored andcommitted
Carries the revert of llvm/llvm-project@b4c31dc. Fixups: llvm/llvm-project@51a1aab - Use `populateExpansionPatterns` instead of `populateExpandXXXXPattern` llvm/llvm-project@685a98c - Support array result for emitc.member and emitc.member_of_ptr. Support the new result type when possible, otherwise perform a dynamic cast to the previously supported type. llvm/llvm-project@13471e1 - Avoid requirement to specify enum name for enum attributes i.e ```mlir rounding_mode = "SINGLE_ROUND" // becomes rounding_mode = SINGLE_ROUND. ``` --------- Signed-off-by: Muzammiluddin Syed <[email protected]> Signed-off-by: Ivan Ho <[email protected]>
1 parent 7ed6488 commit 3315fb2

File tree

8 files changed

+67
-60
lines changed

8 files changed

+67
-60
lines changed

compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,24 @@ namespace mlir::iree_compiler {
2222
static void populateMathFunctionsRewritePatterns(
2323
RewritePatternSet &patterns,
2424
const std::function<bool(StringRef)> &predicate) {
25-
if (predicate(math::TanOp::getOperationName())) {
26-
populateExpandTanPattern(patterns);
27-
}
28-
if (predicate(math::SinhOp::getOperationName())) {
29-
populateExpandSinhPattern(patterns);
30-
}
31-
if (predicate(math::CoshOp::getOperationName())) {
32-
populateExpandCoshPattern(patterns);
33-
}
34-
if (predicate(math::AsinhOp::getOperationName())) {
35-
populateExpandAsinhPattern(patterns);
36-
}
37-
if (predicate(math::AcoshOp::getOperationName())) {
38-
populateExpandAcoshPattern(patterns);
39-
}
40-
if (predicate(math::AtanhOp::getOperationName())) {
41-
populateExpandAtanhPattern(patterns);
42-
}
43-
if (predicate(math::PowFOp::getOperationName())) {
44-
populateExpandPowFPattern(patterns);
45-
}
46-
if (predicate(math::FPowIOp::getOperationName())) {
47-
populateExpandFPowIPattern(patterns);
48-
}
49-
if (predicate(math::Exp2Op::getOperationName())) {
50-
populateExpandExp2FPattern(patterns);
51-
}
52-
if (predicate(math::RoundEvenOp::getOperationName())) {
53-
populateExpandRoundEvenPattern(patterns);
25+
llvm::SmallVector<StringRef> opNames,
26+
opFullNames = {math::TanOp::getOperationName(),
27+
math::SinhOp::getOperationName(),
28+
math::CoshOp::getOperationName(),
29+
math::AsinhOp::getOperationName(),
30+
math::AcoshOp::getOperationName(),
31+
math::AtanhOp::getOperationName(),
32+
math::PowFOp::getOperationName(),
33+
math::FPowIOp::getOperationName(),
34+
math::Exp2Op::getOperationName(),
35+
math::RoundEvenOp::getOperationName()};
36+
size_t prefix = math::MathDialect::getDialectNamespace().size() + 1;
37+
for (auto name : opFullNames) {
38+
if (predicate(name)) {
39+
opNames.push_back(name.drop_front(prefix));
40+
}
5441
}
42+
math::populateExpansionPatterns(patterns, opNames);
5543
}
5644

5745
static bool predicateRewrite(StringRef name,

compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,12 +1049,17 @@ void ConvertToLLVMPass::runOnOperation() {
10491049
if (use32BitImpl) {
10501050
patterns.add<ExpandMulSIExtended>(patterns.getContext(), /*benefit=*/1024);
10511051
}
1052-
1052+
auto populateTanhPatterns = [](RewritePatternSet &p) {
1053+
StringRef fname = math::TanhOp::getOperationName();
1054+
size_t prefix = math::MathDialect::getDialectNamespace().size() + 1;
1055+
StringRef opName = fname.drop_front(prefix);
1056+
math::populateExpansionPatterns(p, /*OpMnemonics=*/{opName});
1057+
};
10531058
LLVMConversionTarget target(getContext());
10541059
populateAffineToStdConversionPatterns(patterns);
10551060
populateSCFToControlFlowConversionPatterns(patterns);
10561061
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
1057-
populateExpandTanhPattern(patterns);
1062+
populateTanhPatterns(patterns);
10581063

10591064
populateComplexToLLVMConversionPatterns(typeConverter, patterns);
10601065
populateMathToLLVMConversionPatterns(typeConverter, patterns);

compiler/src/iree/compiler/Codegen/LLVMCPU/test/apply_scale_lowering.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ hal.executable private @apply_scale_no_vector_feature {
2727
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : memref<2xi32>
2828
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<2xi32>
2929
%2 = vector.load %0[%c0] : memref<2xi32>, vector<2xi32>
30-
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = "SINGLE_ROUND"} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
30+
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = SINGLE_ROUND} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
3131
vector.store %3, %1[%c0] : memref<2xi32>, vector<2xi32>
3232
return
3333
}
@@ -72,7 +72,7 @@ hal.executable private @apply_scale_v {
7272
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : memref<2xi32>
7373
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<2xi32>
7474
%2 = vector.load %0[%c0] : memref<2xi32>, vector<2xi32>
75-
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = "SINGLE_ROUND"} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
75+
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = SINGLE_ROUND} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
7676
vector.store %3, %1[%c0] : memref<2xi32>, vector<2xi32>
7777
return
7878
}
@@ -115,7 +115,7 @@ hal.executable private @apply_scale_zve64x {
115115
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : memref<2xi32>
116116
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<2xi32>
117117
%2 = vector.load %0[%c0] : memref<2xi32>, vector<2xi32>
118-
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = "SINGLE_ROUND"} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
118+
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = SINGLE_ROUND} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
119119
vector.store %3, %1[%c0] : memref<2xi32>, vector<2xi32>
120120
return
121121
}
@@ -158,7 +158,7 @@ hal.executable private @apply_scale_zve32x {
158158
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : memref<2xi32>
159159
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<2xi32>
160160
%2 = vector.load %0[%c0] : memref<2xi32>, vector<2xi32>
161-
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = "SINGLE_ROUND"} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
161+
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = SINGLE_ROUND} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
162162
vector.store %3, %1[%c0] : memref<2xi32>, vector<2xi32>
163163
return
164164
}
@@ -208,7 +208,7 @@ hal.executable private @apply_scale_zve32f {
208208
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : memref<2xi32>
209209
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<2xi32>
210210
%2 = vector.load %0[%c0] : memref<2xi32>, vector<2xi32>
211-
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = "SINGLE_ROUND"} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
211+
%3 = tosa.apply_scale %2, %cst, %cst_0 {rounding_mode = SINGLE_ROUND} : (vector<2xi32>, vector<2xi32>, vector<2xi8>) -> vector<2xi32>
212212
vector.store %3, %1[%c0] : memref<2xi32>, vector<2xi32>
213213
return
214214
}

compiler/src/iree/compiler/Codegen/VMVX/test/select_lowering_strategy.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func.func @fusion_quant_matmul_generic() attributes {hal.executable.target = #ex
127127
%16 = arith.muli %in_1, %c-128_i32 : i32
128128
%17 = arith.subi %in_0, %16 : i32
129129
%18 = arith.addi %in, %17 : i32
130-
%19 = tosa.apply_scale %18, %c1101627623_i32, %c36_i8 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
130+
%19 = tosa.apply_scale %18, %c1101627623_i32, %c36_i8 {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
131131
%20 = arith.addi %19, %c-128_i32 : i32
132132
%21 = arith.cmpi slt, %20, %c-128_i32 : i32
133133
%22 = arith.select %21, %c-128_i32, %20 : i32

compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,13 +2282,14 @@ class ImportOpConverter {
22822282
auto ctx = builder.getContext();
22832283

22842284
// byteSpan = call.<memberName>;
2285-
auto byteSpan = builder
2286-
.create<emitc::MemberOp>(
2287-
loc,
2288-
emitc::LValueType::get(emitc::OpaqueType::get(
2289-
ctx, "iree_byte_span_t")),
2290-
memberName, call)
2291-
.getResult();
2285+
TypedValue<mlir::Type> byteSpan =
2286+
builder
2287+
.create<emitc::MemberOp>(
2288+
loc,
2289+
emitc::LValueType::get(
2290+
emitc::OpaqueType::get(ctx, "iree_byte_span_t")),
2291+
memberName, call)
2292+
.getResult();
22922293

22932294
// alloca_(0) returns NULL in some configurations on Windows. Make sure to
22942295
// allocate at least one byte to get a valid pointer.
@@ -2500,25 +2501,32 @@ class ImportOpConverter {
25002501
auto ctx = builder.getContext();
25012502

25022503
// RETURN_IF_ERROR(import->module->begin_call(import->module, stack, call));
2503-
auto importModule = builder.create<emitc::MemberOfPtrOp>(
2504+
auto im = builder.create<emitc::MemberOfPtrOp>(
25042505
loc,
25052506
/*type=*/
25062507
emitc::LValueType::get(emitc::PointerType::get(
25072508
emitc::OpaqueType::get(ctx, "iree_vm_module_t"))),
25082509
/*memberName=*/"module",
25092510
/*operand=*/import);
2511+
auto importModule = dyn_cast<TypedValue<emitc::LValueType>>(im.getResult());
2512+
if (!importModule) {
2513+
return failure();
2514+
}
25102515

25112516
// EmitC can't emit function pointers, so we need to fallback to a typedef
25122517
// currently. This and the `EMITC_CALL_INDIRECT` macro can be replaced with
25132518
// a new `emitc.call_indirect` op once it has been added upstream.
25142519
emitc::OpaqueType type = emitc::OpaqueType::get(ctx, "begin_call_t");
25152520

2516-
auto beginCall =
2521+
auto bc =
25172522
builder
25182523
.create<emitc::MemberOfPtrOp>(loc, emitc::LValueType::get(type),
25192524
"begin_call", importModule)
25202525
.getResult();
2521-
2526+
auto beginCall = dyn_cast<TypedValue<emitc::LValueType>>(bc);
2527+
if (!beginCall) {
2528+
return failure();
2529+
}
25222530
returnIfError(
25232531
/*rewriter=*/builder,
25242532
/*location=*/loc,

compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,12 @@ void structDefinition(OpBuilder builder, Location location,
233233
}
234234

235235
Value structMember(OpBuilder builder, Location location, Type type,
236-
StringRef memberName,
237-
TypedValue<emitc::LValueType> operand) {
238-
TypedValue<emitc::LValueType> member = builder.create<emitc::MemberOp>(
239-
location, emitc::LValueType::get(type), memberName, operand);
236+
StringRef memberName, TypedValue<mlir::Type> operand) {
237+
TypedValue<mlir::Type> member =
238+
builder
239+
.create<emitc::MemberOp>(location, emitc::LValueType::get(type),
240+
memberName, operand)
241+
.getResult();
240242
return builder.create<emitc::LoadOp>(location, type, member).getResult();
241243
}
242244

@@ -246,12 +248,14 @@ structMemberAddress(OpBuilder builder, Location location,
246248
TypedValue<emitc::LValueType> operand) {
247249
auto member = builder.create<emitc::MemberOp>(location, type.getPointee(),
248250
memberName, operand);
249-
return addressOf(builder, location, member.getResult());
251+
return addressOf(
252+
builder, location,
253+
llvm::cast<TypedValue<emitc::LValueType>>(member.getResult()));
250254
}
251255

252256
void structMemberAssign(OpBuilder builder, Location location,
253-
StringRef memberName,
254-
TypedValue<emitc::LValueType> operand, Value data) {
257+
StringRef memberName, TypedValue<mlir::Type> operand,
258+
Value data) {
255259
Value member = builder.create<emitc::MemberOp>(
256260
location, emitc::LValueType::get(data.getType()), memberName, operand);
257261
builder.create<emitc::AssignOp>(location, member, data);
@@ -260,7 +264,7 @@ void structMemberAssign(OpBuilder builder, Location location,
260264
Value structPtrMember(OpBuilder builder, Location location, Type type,
261265
StringRef memberName,
262266
TypedValue<emitc::LValueType> operand) {
263-
TypedValue<emitc::LValueType> member = builder.create<emitc::MemberOfPtrOp>(
267+
TypedValue<mlir::Type> member = builder.create<emitc::MemberOfPtrOp>(
264268
location, emitc::LValueType::get(type), memberName, operand);
265269
return builder.create<emitc::LoadOp>(location, type, member).getResult();
266270
}
@@ -271,7 +275,9 @@ structPtrMemberAddress(OpBuilder builder, Location location,
271275
TypedValue<emitc::LValueType> operand) {
272276
auto member = builder.create<emitc::MemberOfPtrOp>(
273277
location, emitc::LValueType::get(type.getPointee()), memberName, operand);
274-
return addressOf(builder, location, member.getResult());
278+
return addressOf(
279+
builder, location,
280+
llvm::cast<TypedValue<emitc::LValueType>>(member.getResult()));
275281
}
276282

277283
void structPtrMemberAssign(OpBuilder builder, Location location,

compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,16 @@ void structDefinition(OpBuilder builder, Location location,
103103
StringRef structName, ArrayRef<StructField> fields);
104104

105105
Value structMember(OpBuilder builder, Location location, Type type,
106-
StringRef memberName, TypedValue<emitc::LValueType> operand);
106+
StringRef memberName, TypedValue<mlir::Type> operand);
107107

108108
TypedValue<emitc::PointerType>
109109
structMemberAddress(OpBuilder builder, Location location,
110110
emitc::PointerType type, StringRef memberName,
111111
TypedValue<emitc::LValueType> operand);
112112

113113
void structMemberAssign(OpBuilder builder, Location location,
114-
StringRef memberName,
115-
TypedValue<emitc::LValueType> operand, Value data);
114+
StringRef memberName, TypedValue<mlir::Type> operand,
115+
Value data);
116116

117117
Value structPtrMember(OpBuilder builder, Location location, Type type,
118118
StringRef memberName,

third_party/llvm-project

Submodule llvm-project updated 2469 files

0 commit comments

Comments
 (0)