Skip to content

Commit f84aaa6

Browse files
[mlir][Transforms] Dialect conversion: Add flag to dump materialization kind (#119532)
Add a debugging flag to the dialect conversion to dump the materialization kind. This flag is useful to find out whether a missing materialization rule is for source or target materializations. Also add missing test coverage for the `buildMaterializations` flag.
1 parent ba45ac6 commit f84aaa6

File tree

4 files changed

+41
-15
lines changed

4 files changed

+41
-15
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,12 @@ struct ConversionConfig {
13001300
/// The folding mode to use during conversion.
13011301
DialectConversionFoldingMode foldingMode =
13021302
DialectConversionFoldingMode::BeforePatterns;
1303+
1304+
/// If set to "true", the materialization kind ("source" or "target") will be
1305+
/// attached to "builtin.unrealized_conversion_cast" ops. This flag is useful
1306+
/// for debugging, to find out what kind of materialization rule may be
1307+
/// missing.
1308+
bool attachDebugMaterializationKind = false;
13031309
};
13041310

13051311
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,6 +1637,11 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
16371637
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
16381638
UnrealizedConversionCastOp convertOp =
16391639
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1640+
if (config.attachDebugMaterializationKind) {
1641+
StringRef kindStr =
1642+
kind == MaterializationKind::Source ? "source" : "target";
1643+
convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1644+
}
16401645
if (isPureTypeConversion)
16411646
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
16421647

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
22
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
33
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
4+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
45

56
// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
67
// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
@@ -190,9 +191,12 @@ func.func @remap_drop_region() {
190191
// -----
191192

192193
// CHECK-LABEL: func @dropped_input_in_use
194+
// CHECK-KIND-LABEL: func @dropped_input_in_use
193195
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
194-
// CHECK-NEXT: "test.cast"{{.*}} : () -> i16
195-
// CHECK-NEXT: "work"{{.*}} : (i16)
196+
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
197+
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
198+
// CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
199+
// CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
196200
// expected-remark@+1 {{op 'work' is not legalizable}}
197201
"work"(%arg) : (i16) -> ()
198202
}
@@ -430,6 +434,11 @@ func.func @test_multiple_1_to_n_replacement() {
430434
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
431435
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
432436
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
437+
// CHECK-KIND-LABEL: func @test_lookup_without_converter
438+
// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
439+
// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
440+
// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
441+
// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
433442
func.func @test_lookup_without_converter() {
434443
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
435444
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,15 +1574,19 @@ struct TestLegalizePatternDriver
15741574
target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
15751575
[](ConvertBlockArgsOp op) { return op.getIsLegal(); });
15761576

1577+
// Set up configuration.
1578+
ConversionConfig config;
1579+
config.allowPatternRollback = allowPatternRollback;
1580+
config.foldingMode = foldingMode;
1581+
config.buildMaterializations = buildMaterializations;
1582+
config.attachDebugMaterializationKind = attachDebugMaterializationKind;
1583+
DumpNotifications dumpNotifications;
1584+
config.listener = &dumpNotifications;
1585+
15771586
// Handle a partial conversion.
15781587
if (mode == ConversionMode::Partial) {
15791588
DenseSet<Operation *> unlegalizedOps;
1580-
ConversionConfig config;
1581-
config.allowPatternRollback = allowPatternRollback;
1582-
DumpNotifications dumpNotifications;
1583-
config.listener = &dumpNotifications;
15841589
config.unlegalizedOps = &unlegalizedOps;
1585-
config.foldingMode = foldingMode;
15861590
if (failed(applyPartialConversion(getOperation(), target,
15871591
std::move(patterns), config))) {
15881592
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1600,11 +1604,6 @@ struct TestLegalizePatternDriver
16001604
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
16011605
});
16021606

1603-
ConversionConfig config;
1604-
config.allowPatternRollback = allowPatternRollback;
1605-
DumpNotifications dumpNotifications;
1606-
config.foldingMode = foldingMode;
1607-
config.listener = &dumpNotifications;
16081607
if (failed(applyFullConversion(getOperation(), target,
16091608
std::move(patterns), config))) {
16101609
getOperation()->emitRemark() << "applyFullConversion failed";
@@ -1617,9 +1616,6 @@ struct TestLegalizePatternDriver
16171616

16181617
// Analyze the convertible operations.
16191618
DenseSet<Operation *> legalizedOps;
1620-
ConversionConfig config;
1621-
config.foldingMode = foldingMode;
1622-
config.allowPatternRollback = allowPatternRollback;
16231619
config.legalizableOps = &legalizedOps;
16241620
if (failed(applyAnalysisConversion(getOperation(), target,
16251621
std::move(patterns), config)))
@@ -1658,6 +1654,16 @@ struct TestLegalizePatternDriver
16581654
Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
16591655
llvm::cl::desc("Allow pattern rollback"),
16601656
llvm::cl::init(true)};
1657+
Option<bool> attachDebugMaterializationKind{
1658+
*this, "attach-debug-materialization-kind",
1659+
llvm::cl::desc(
1660+
"Attach materialization kind to unrealized_conversion_cast ops"),
1661+
llvm::cl::init(false)};
1662+
Option<bool> buildMaterializations{
1663+
*this, "build-materializations",
1664+
llvm::cl::desc(
1665+
"If set to 'false', leave unrealized_conversion_cast ops in place"),
1666+
llvm::cl::init(true)};
16611667
};
16621668
} // namespace
16631669

0 commit comments

Comments
 (0)