Skip to content

Commit f8c62c9

Browse files
krzysz00AWoloszyn
authored andcommitted
Integrate llvm-project@faf5d747f174cc (#20828)
Integrate torch-mlir@389541fb9ddd Integrate stablehlo@5837b2a6ce # Dropped reverts - Drop revert of the 1:N dialect conversion removal since all dependencies have migrated - Drop revert of the APIntParameter error since all dependencies have migrated - Drop revert of allowing function type conversion to fail since all dependencies have migrated - Drop local modifications to torch-mlir and stablehlo, those are now clean submodules (they're no longer needed now that we've dropped reverts and they've migrated) # Continued reverts - We still have a revert of upstream #137930 since it's not clear the roccertness issue is resolved (that's the SDWA cndmask thing, see llvm/llvm-project#138766 ) - We still have a revert of upstream #133231 since that could still be breaking tests # Changes - Rename bufferization.to_memref to bufferization.to_buffer everewhere - Swap all getSource() on Transfer*Op to getBase() - Rename the memref argument of TransferGatherOp to `base` to match the transfer interface - Remove argument materialization calls - In the one case where this wasn't trivial, migrate to a target materilazition since that's what upstream advice suggested - Handle (in any way that seemed appropriate) the failures that eraseArguments() and eraseResults() can now have - Slightly reshuffle the LLVMCPU pipeline so that there's an `affine-expand-index-ops` after the last `FoldMemRefAliasOps` call, because `FoldMemRefAliasOps` now creates `affine.linearize_index` and `affine.delinearize_index` which don't seem to lower to LLVM right on their own. See llvm/llvm-project#138930 - Update narrow type emulation tests to account for correctness fixes in the lineraized shape determination widget - Update vectorization tests to account for changes in ee47454bb8be and update the attention tiling test to account for an unknown change - Update StableHLO rewrites to account for accuracy arguments
1 parent 26f2677 commit f8c62c9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+205
-165
lines changed

compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def ConstantLikePosInfValue : NativeCodeCall<
3838
def ConstantLikeNegInfValue : NativeCodeCall<
3939
"::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">;
4040

41+
def STABLEHLO_DEFAULT_RESULT_ACCURACY :
42+
ConstantAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">;
43+
4144
//===----------------------------------------------------------------------===//
4245
// Unary op patterns.
4346
//===----------------------------------------------------------------------===//
@@ -64,7 +67,8 @@ def : Pat<(CHLO_AcosOp NonComplexElementType:$input),
6467
(StableHLO_SubtractOp
6568
(ConstantLike<"1"> $input),
6669
(StableHLO_MulOp $input, $input)
67-
)
70+
),
71+
STABLEHLO_DEFAULT_RESULT_ACCURACY
6872
),
6973
(StableHLO_AddOp
7074
(ConstantLike<"1"> $input),
@@ -96,15 +100,17 @@ def : Pat<(CHLO_AcoshOp NonComplexElementType:$input),
96100
(StableHLO_CompareOp
97101
$input,
98102
(StableHLO_SqrtOp
99-
(ConstantLikeMaxFiniteValue $input)
103+
(ConstantLikeMaxFiniteValue $input),
104+
STABLEHLO_DEFAULT_RESULT_ACCURACY
100105
),
101106
StableHLO_ComparisonDirectionValue<"GE">,
102107
(STABLEHLO_DEFAULT_COMPARISON_TYPE)
103108
),
104109
(StableHLO_AddOp
105-
(StableHLO_LogOp $input),
110+
(StableHLO_LogOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
106111
(StableHLO_LogOp
107-
(ConstantLike<"2"> $input)
112+
(ConstantLike<"2"> $input),
113+
STABLEHLO_DEFAULT_RESULT_ACCURACY
108114
)
109115
),
110116
(StableHLO_LogOp
@@ -120,9 +126,11 @@ def : Pat<(CHLO_AcoshOp NonComplexElementType:$input),
120126
(ConstantLike<"-1"> $input),
121127
$input
122128
)
123-
)
129+
),
130+
STABLEHLO_DEFAULT_RESULT_ACCURACY
124131
)
125-
)
132+
),
133+
STABLEHLO_DEFAULT_RESULT_ACCURACY
126134
)
127135
)
128136
)>;
@@ -148,9 +156,11 @@ def : Pat<(CHLO_AcoshOp ComplexElementType:$input),
148156
$input,
149157
(ConstantLike<"1"> $input)
150158
)
151-
)
159+
),
160+
STABLEHLO_DEFAULT_RESULT_ACCURACY
152161
)
153-
)
162+
),
163+
STABLEHLO_DEFAULT_RESULT_ACCURACY
154164
)>;
155165

156166

@@ -167,7 +177,8 @@ def : Pat<(CHLO_AsinOp $input),
167177
(StableHLO_SubtractOp
168178
(ConstantLike<"1"> $input),
169179
(StableHLO_MulOp $input, $input)
170-
)
180+
),
181+
STABLEHLO_DEFAULT_RESULT_ACCURACY
171182
)
172183
)
173184
)
@@ -200,17 +211,20 @@ def : Pat<(CHLO_AsinhOp NonComplexElementType:$input),
200211
(StableHLO_CompareOp
201212
(StableHLO_AbsOp $input),
202213
(StableHLO_SqrtOp
203-
(ConstantLikeMaxFiniteValue $input)
214+
(ConstantLikeMaxFiniteValue $input),
215+
STABLEHLO_DEFAULT_RESULT_ACCURACY
204216
),
205217
StableHLO_ComparisonDirectionValue<"GE">,
206218
(STABLEHLO_DEFAULT_COMPARISON_TYPE)
207219
),
208220
(StableHLO_AddOp
209221
(StableHLO_LogOp
210-
(StableHLO_AbsOp $input)
222+
(StableHLO_AbsOp $input),
223+
STABLEHLO_DEFAULT_RESULT_ACCURACY
211224
),
212225
(StableHLO_LogOp
213-
(ConstantLike<"2"> $input)
226+
(ConstantLike<"2"> $input),
227+
STABLEHLO_DEFAULT_RESULT_ACCURACY
214228
)
215229
),
216230
(StableHLO_SelectOp
@@ -236,12 +250,14 @@ def : Pat<(CHLO_AsinhOp NonComplexElementType:$input),
236250
(StableHLO_AbsOp $input)
237251
),
238252
(ConstantLike<"1"> $input)
239-
)
253+
),
254+
STABLEHLO_DEFAULT_RESULT_ACCURACY
240255
)
241256
)
242257
)
243258
)
244-
)
259+
),
260+
STABLEHLO_DEFAULT_RESULT_ACCURACY
245261
),
246262
(StableHLO_LogOp
247263
(StableHLO_AddOp
@@ -253,9 +269,11 @@ def : Pat<(CHLO_AsinhOp NonComplexElementType:$input),
253269
(StableHLO_AbsOp $input)
254270
),
255271
(ConstantLike<"1"> $input)
256-
)
272+
),
273+
STABLEHLO_DEFAULT_RESULT_ACCURACY
257274
)
258-
)
275+
),
276+
STABLEHLO_DEFAULT_RESULT_ACCURACY
259277
)
260278
)
261279
)
@@ -276,9 +294,11 @@ def : Pat<(CHLO_AsinhOp ComplexElementType:$input),
276294
(StableHLO_AddOp
277295
(StableHLO_MulOp $input, $input),
278296
(ConstantLike<"1"> $input)
279-
)
297+
),
298+
STABLEHLO_DEFAULT_RESULT_ACCURACY
280299
)
281-
)
300+
),
301+
STABLEHLO_DEFAULT_RESULT_ACCURACY
282302
)>;
283303

284304
// Express `atan` as
@@ -303,9 +323,10 @@ def : Pat<(CHLO_AtanhOp NonComplexElementType:$input),
303323
(ConstantLike<"NAN"> $input),
304324
(StableHLO_MulOp
305325
(StableHLO_SubtractOp
306-
(StableHLO_Log1pOp $input),
326+
(StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
307327
(StableHLO_Log1pOp
308-
(StableHLO_NegOp $input)
328+
(StableHLO_NegOp $input),
329+
STABLEHLO_DEFAULT_RESULT_ACCURACY
309330
)
310331
),
311332
(ConstantLike<"0.5"> $input)
@@ -321,9 +342,10 @@ def : Pat<(CHLO_AtanhOp NonComplexElementType:$input),
321342
def : Pat<(CHLO_AtanhOp ComplexElementType:$input),
322343
(StableHLO_MulOp
323344
(StableHLO_SubtractOp
324-
(StableHLO_Log1pOp $input),
345+
(StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
325346
(StableHLO_Log1pOp
326-
(StableHLO_NegOp $input)
347+
(StableHLO_NegOp $input),
348+
STABLEHLO_DEFAULT_RESULT_ACCURACY
327349
)
328350
),
329351
(ConstantLike<"0.5"> $input)
@@ -365,8 +387,8 @@ def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input),
365387
// sine(x) / cosine(x)
366388
def : Pat<(CHLO_TanOp NonComplexElementType:$input),
367389
(StableHLO_DivOp
368-
(StableHLO_SineOp $input),
369-
(StableHLO_CosineOp $input)
390+
(StableHLO_SineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
391+
(StableHLO_CosineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY)
370392
)>;
371393

372394

@@ -376,7 +398,7 @@ def : Pat<(CHLO_TanOp ComplexElementType:$input),
376398
(StableHLO_DivOp
377399
(StableHLO_ComplexOp
378400
(CHLO_TanOp:$tan (StableHLO_RealOp $input)),
379-
(StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
401+
(StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input), STABLEHLO_DEFAULT_RESULT_ACCURACY)),
380402
(StableHLO_ComplexOp
381403
(ConstantLike<"1.0"> $tan),
382404
(StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))

compiler/plugins/input/StableHLO/Conversion/Preprocessing/ComplexLoweringPatterns.td

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ include "stablehlo/dialect/StablehloOps.td"
1414
class ConstantSplat<string value> : NativeCodeCall<
1515
"::mlir::iree_compiler::stablehlo::getSplat(&$_builder, $0, " # value # ")">;
1616

17+
def STABLEHLO_DEFAULT_RESULT_ACCURACY :
18+
ConstantAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">;
19+
1720
//===----------------------------------------------------------------------===//
1821
// Binary op patterns.
1922
//===----------------------------------------------------------------------===//
@@ -68,42 +71,43 @@ def : Pat<(StableHLO_AbsOp HLO_ComplexTensor:$val),
6871
(StableHLO_SqrtOp
6972
(StableHLO_AddOp
7073
(StableHLO_MulOp (StableHLO_RealOp:$real $val), $real),
71-
(StableHLO_MulOp (StableHLO_ImagOp:$imag $val), $imag)))>;
74+
(StableHLO_MulOp (StableHLO_ImagOp:$imag $val), $imag)),
75+
STABLEHLO_DEFAULT_RESULT_ACCURACY)>;
7276

7377
// Can deconstruct sin(a + ib) as follows:
7478
// sin(a) * cosh(b) + icos(a) * sinh(b)
7579
// sinh(b) = (e^x - e^-x) / 2
7680
// cosh(b) = (e^x + e^-x) / 2
77-
def : Pat<(StableHLO_SineOp HLO_ComplexTensor:$val),
81+
def : Pat<(StableHLO_SineOp HLO_ComplexTensor:$val, $accuracy),
7882
(StableHLO_ComplexOp
7983
(StableHLO_DivOp
8084
(StableHLO_MulOp
81-
(StableHLO_SineOp (StableHLO_RealOp:$real $val)),
85+
(StableHLO_SineOp (StableHLO_RealOp:$real $val), $accuracy),
8286
(StableHLO_AddOp
83-
(StableHLO_ExpOp:$exp (StableHLO_ImagOp:$imag $val)),
84-
(StableHLO_ExpOp:$nexp (StableHLO_NegOp $imag)))),
87+
(StableHLO_ExpOp:$exp (StableHLO_ImagOp:$imag $val), $accuracy),
88+
(StableHLO_ExpOp:$nexp (StableHLO_NegOp $imag), $accuracy))),
8589
(StableHLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))),
8690
(StableHLO_DivOp
8791
(StableHLO_MulOp
88-
(StableHLO_CosineOp $real),
92+
(StableHLO_CosineOp $real, $accuracy),
8993
(StableHLO_SubtractOp $exp, $nexp)), $two))>;
9094

9195
// Can deconstruct cos(a + ib) as follows:
9296
// cos(a) * cosh(b) - isin(a) * sinh(b)
9397
// sinh(b) = (e^x - e^-x) / 2
9498
// cosh(b) = (e^x + e^-x) / 2
95-
def : Pat<(StableHLO_CosineOp HLO_ComplexTensor:$val),
99+
def : Pat<(StableHLO_CosineOp HLO_ComplexTensor:$val, $accuracy),
96100
(StableHLO_ComplexOp
97101
(StableHLO_DivOp
98102
(StableHLO_MulOp
99-
(StableHLO_CosineOp (StableHLO_RealOp:$real $val)),
103+
(StableHLO_CosineOp (StableHLO_RealOp:$real $val), $accuracy),
100104
(StableHLO_AddOp
101-
(StableHLO_ExpOp:$exp (StableHLO_ImagOp:$imag $val)),
102-
(StableHLO_ExpOp:$nexp (StableHLO_NegOp $imag)))),
105+
(StableHLO_ExpOp:$exp (StableHLO_ImagOp:$imag $val), $accuracy),
106+
(StableHLO_ExpOp:$nexp (StableHLO_NegOp $imag), $accuracy))),
103107
(StableHLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))),
104108
(StableHLO_DivOp
105109
(StableHLO_MulOp
106-
(StableHLO_SineOp $real),
110+
(StableHLO_SineOp $real, $accuracy),
107111
(StableHLO_SubtractOp $nexp, $exp)), $two))>;
108112

109113
// Exponential can be lowered to an exponential on the real component and a
@@ -114,12 +118,12 @@ def : Pat<(StableHLO_CosineOp HLO_ComplexTensor:$val),
114118
class StableHLO_ComparisonDirectionValue<string enumStr> :
115119
ConstantAttr<StableHLO_ComparisonDirectionAttr, "::mlir::stablehlo::ComparisonDirection::" # enumStr>;
116120

117-
def : Pat<(StableHLO_ExpOp HLO_ComplexTensor:$val),
121+
def : Pat<(StableHLO_ExpOp HLO_ComplexTensor:$val, $accuracy),
118122
(StableHLO_ComplexOp
119123
(StableHLO_MulOp
120-
(StableHLO_CosineOp (StableHLO_ImagOp:$imag $val)),
121-
(StableHLO_ExpOp:$exp (StableHLO_RealOp:$real $val))),
122-
(StableHLO_MulOp (StableHLO_SineOp $imag), $exp))>;
124+
(StableHLO_CosineOp (StableHLO_ImagOp:$imag $val), $accuracy),
125+
(StableHLO_ExpOp:$exp (StableHLO_RealOp:$real $val), $accuracy)),
126+
(StableHLO_MulOp (StableHLO_SineOp $imag, $accuracy), $exp))>;
123127

124128
foreach pair = [[StableHLO_ComparisonDirectionValue<"NE">, StableHLO_OrOp],
125129
[StableHLO_ComparisonDirectionValue<"EQ">, StableHLO_AndOp]] in {

compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,6 @@ struct ConvertStableHloToIreeInputDialects final
507507

508508
std::unique_ptr<TypeConverter> typeConverter =
509509
std::make_unique<::mlir::stablehlo::LinalgTypeConverter>();
510-
typeConverter->addArgumentMaterialization(scalarToTensor);
511510
typeConverter->addSourceMaterialization(scalarToTensor);
512511
typeConverter->addTargetMaterialization(scalarToTensor);
513512

compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ struct StableHloToStdTypeConverter final : TypeConverter {
105105
addConversion(convertRank0TensorToScalar);
106106
addConversion(convertIntegerToSignless);
107107

108-
addArgumentMaterialization(materializeCast);
109108
addSourceMaterialization(materializeCast);
110109
addTargetMaterialization(materializeCast);
111110
}

compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) {
371371
return insertSliceOp.getDest();
372372
}
373373
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) {
374-
return transferWriteOp.getSource();
374+
return transferWriteOp.getBase();
375375
}
376376
return nullptr;
377377
};
@@ -380,7 +380,7 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) {
380380
return extractSliceOp.getSource();
381381
}
382382
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) {
383-
return transferReadOp.getSource();
383+
return transferReadOp.getBase();
384384
}
385385
return nullptr;
386386
};
@@ -561,27 +561,27 @@ LogicalResult createTensorEquivalenceClasses(mlir::FunctionOpInterface funcOp,
561561
.Case<vector::TransferReadOp>(
562562
[&](vector::TransferReadOp transferReadOp) {
563563
if (llvm::isa<RankedTensorType>(
564-
transferReadOp.getSource().getType())) {
565-
plan.insert(transferReadOp.getSource());
564+
transferReadOp.getBase().getType())) {
565+
plan.insert(transferReadOp.getBase());
566566
}
567567
return success();
568568
})
569569
.Case<vector::TransferWriteOp>(
570570
[&](vector::TransferWriteOp transferWriteOp) {
571571
if (!llvm::isa<RankedTensorType>(
572-
transferWriteOp.getSource().getType())) {
572+
transferWriteOp.getBase().getType())) {
573573
return success();
574574
}
575575
return analyseDestructiveUpdateOp(
576-
transferWriteOp, nullptr, transferWriteOp.getSource(),
576+
transferWriteOp, nullptr, transferWriteOp.getBase(),
577577
transferWriteOp.getResult(), plan);
578578
})
579579
.Case<scf::IfOp>(
580580
[&](scf::IfOp ifOp) { return analyseScfIfOp(ifOp, plan); })
581581
.Case<scf::ForOp>(
582582
[&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); })
583583
.Case<scf::YieldOp, tensor::EmptyOp, tensor::DimOp, tensor::ExtractOp,
584-
tensor::GenerateOp, tensor::PadOp, bufferization::ToMemrefOp,
584+
tensor::GenerateOp, tensor::PadOp, bufferization::ToBufferOp,
585585
bufferization::AllocTensorOp>(
586586
[&](Operation *op) { return success(); })
587587
.Default([&](Operation *op) -> LogicalResult {
@@ -609,8 +609,8 @@ LogicalResult createTensorEquivalenceClasses(mlir::FunctionOpInterface funcOp,
609609
return;
610610
}
611611
if (auto vectorWriteOp = dyn_cast<vector::TransferWriteOp>(updateOp)) {
612-
if (isa<RankedTensorType>(vectorWriteOp.getSource().getType())) {
613-
hasDestructiveUpdatePattern(vectorWriteOp.getSource(), plan);
612+
if (isa<RankedTensorType>(vectorWriteOp.getBase().getType())) {
613+
hasDestructiveUpdatePattern(vectorWriteOp.getBase(), plan);
614614
}
615615
}
616616
});

compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ template <typename SourceType, typename TargetType>
104104
struct FloatTypeConverter
105105
: public PrimitiveTypeConverter<SourceType, TargetType> {
106106
explicit FloatTypeConverter() {
107-
this->addArgumentMaterialization(convertRankedFloat);
108107
this->addSourceMaterialization(convertRankedFloat);
109108
this->addTargetMaterialization(convertRankedFloat);
110109
}

compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,6 @@ struct ConvertBf16ToUInt16BuffersPass final
286286
MLIRContext *ctx = &getContext();
287287

288288
Bf16EmulationConverter typeConverter;
289-
typeConverter.addArgumentMaterialization(materializeArithBitcast);
290289
typeConverter.addTargetMaterialization(materializeArithBitcast);
291290
typeConverter.addSourceMaterialization(materializeArithBitcast);
292291

0 commit comments

Comments
 (0)