Skip to content

Commit ca84e9e

Browse files
[mlir][CF] Add structural type conversion patterns (#165629)
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of #165180, which changes the way blocks are converted. (Only entry blocks are converted.)
1 parent 21bcd00 commit ca84e9e

File tree

6 files changed

+248
-0
lines changed

6 files changed

+248
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
10+
#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
11+
12+
#include "mlir/IR/PatternMatch.h"
13+
14+
namespace mlir {
15+
16+
class ConversionTarget;
17+
class TypeConverter;
18+
19+
namespace cf {
20+
21+
/// Populates patterns for CF structural type conversions and sets up the
22+
/// provided ConversionTarget with the appropriate legality configuration for
23+
/// the ops to get converted properly.
24+
///
25+
/// A "structural" type conversion is one where the underlying ops are
26+
/// completely agnostic to the actual types involved and simply need to update
27+
/// their types. An example of this is cf.br -- the cf.br op needs to update
28+
/// its types accordingly to the TypeConverter, but otherwise does not care
29+
/// what type conversions are happening.
30+
void populateCFStructuralTypeConversionsAndLegality(
31+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
32+
ConversionTarget &target, PatternBenefit benefit = 1);
33+
34+
/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
35+
/// populate the conversion target.
36+
void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
37+
RewritePatternSet &patterns,
38+
PatternBenefit benefit = 1);
39+
40+
/// Updates the ConversionTarget with dynamic legality of CF operations based
41+
/// on the provided type converter.
42+
void populateCFStructuralTypeConversionTarget(
43+
const TypeConverter &typeConverter, ConversionTarget &target);
44+
45+
} // namespace cf
46+
} // namespace mlir
47+
48+
#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H

mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRControlFlowTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
4+
StructuralTypeConversions.cpp
45

56
ADDITIONAL_HEADER_DIRS
67
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert MLIR standard and builtin dialects
10+
// into the LLVM IR dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
15+
16+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
25+
/// Helper function for converting branch ops. This function converts the
26+
/// signature of the given block. If the new block signature is different from
27+
/// `expectedTypes`, returns "failure".
28+
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
29+
const TypeConverter *converter,
30+
Operation *branchOp, Block *block,
31+
TypeRange expectedTypes) {
32+
assert(converter && "expected non-null type converter");
33+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
34+
35+
// There is nothing to do if the types already match.
36+
if (block->getArgumentTypes() == expectedTypes)
37+
return block;
38+
39+
// Compute the new block argument types and convert the block.
40+
std::optional<TypeConverter::SignatureConversion> conversion =
41+
converter->convertBlockSignature(block);
42+
if (!conversion)
43+
return rewriter.notifyMatchFailure(branchOp,
44+
"could not compute block signature");
45+
if (expectedTypes != conversion->getConvertedTypes())
46+
return rewriter.notifyMatchFailure(
47+
branchOp,
48+
"mismatch between adaptor operand types and computed block signature");
49+
return rewriter.applySignatureConversion(block, *conversion, converter);
50+
}
51+
52+
/// Flatten the given value ranges into a single vector of values.
53+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
54+
SmallVector<Value> result;
55+
for (const ValueRange &vals : values)
56+
llvm::append_range(result, vals);
57+
return result;
58+
}
59+
60+
/// Convert the destination block signature (if necessary) and change the
61+
/// operands of the branch op.
62+
struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
63+
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
64+
65+
LogicalResult
66+
matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
67+
ConversionPatternRewriter &rewriter) const override {
68+
SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
69+
FailureOr<Block *> convertedBlock =
70+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
71+
TypeRange(ValueRange(flattenedAdaptor)));
72+
if (failed(convertedBlock))
73+
return failure();
74+
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
75+
*convertedBlock);
76+
return success();
77+
}
78+
};
79+
80+
/// Convert the destination block signatures (if necessary) and change the
81+
/// operands of the branch op.
82+
struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
83+
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
84+
85+
LogicalResult
86+
matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
87+
ConversionPatternRewriter &rewriter) const override {
88+
SmallVector<Value> flattenedAdaptorTrue =
89+
flattenValues(adaptor.getTrueDestOperands());
90+
SmallVector<Value> flattenedAdaptorFalse =
91+
flattenValues(adaptor.getFalseDestOperands());
92+
if (!llvm::hasSingleElement(adaptor.getCondition()))
93+
return rewriter.notifyMatchFailure(op,
94+
"expected single element condition");
95+
FailureOr<Block *> convertedTrueBlock =
96+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
97+
TypeRange(ValueRange(flattenedAdaptorTrue)));
98+
if (failed(convertedTrueBlock))
99+
return failure();
100+
FailureOr<Block *> convertedFalseBlock =
101+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
102+
TypeRange(ValueRange(flattenedAdaptorFalse)));
103+
if (failed(convertedFalseBlock))
104+
return failure();
105+
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
106+
op, llvm::getSingleElement(adaptor.getCondition()),
107+
flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
108+
*convertedTrueBlock, *convertedFalseBlock);
109+
return success();
110+
}
111+
};
112+
113+
/// Convert the destination block signatures (if necessary) and change the
114+
/// operands of the switch op.
115+
struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
116+
using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
117+
118+
LogicalResult
119+
matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
120+
ConversionPatternRewriter &rewriter) const override {
121+
// Get or convert default block.
122+
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
123+
rewriter, getTypeConverter(), op, op.getDefaultDestination(),
124+
TypeRange(adaptor.getDefaultOperands()));
125+
if (failed(convertedDefaultBlock))
126+
return failure();
127+
128+
// Get or convert all case blocks.
129+
SmallVector<Block *> caseDestinations;
130+
SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
131+
for (auto it : llvm::enumerate(op.getCaseDestinations())) {
132+
Block *b = it.value();
133+
FailureOr<Block *> convertedBlock =
134+
getConvertedBlock(rewriter, getTypeConverter(), op, b,
135+
TypeRange(caseOperands[it.index()]));
136+
if (failed(convertedBlock))
137+
return failure();
138+
caseDestinations.push_back(*convertedBlock);
139+
}
140+
141+
rewriter.replaceOpWithNewOp<cf::SwitchOp>(
142+
op, adaptor.getFlag(), *convertedDefaultBlock,
143+
adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
144+
caseDestinations, caseOperands);
145+
return success();
146+
}
147+
};
148+
149+
} // namespace
150+
151+
void mlir::cf::populateCFStructuralTypeConversions(
152+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
153+
PatternBenefit benefit) {
154+
patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
155+
typeConverter, patterns.getContext(), benefit);
156+
}
157+
158+
void mlir::cf::populateCFStructuralTypeConversionTarget(
159+
const TypeConverter &typeConverter, ConversionTarget &target) {
160+
target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
161+
[&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
162+
}
163+
164+
void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
165+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
166+
ConversionTarget &target, PatternBenefit benefit) {
167+
populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
168+
populateCFStructuralTypeConversionTarget(typeConverter, target);
169+
}

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
143143
return
144144
}
145145

146+
// -----
147+
148+
// CHECK-LABEL: func @test_unstructured_cf_conversion(
149+
// CHECK-SAME: %[[arg0:.*]]: f64, %[[c:.*]]: i1)
150+
// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
151+
// CHECK: "test.foo"(%[[cast1]])
152+
// CHECK: cf.br ^[[bb1:.*]](%[[arg0]] : f64)
153+
// CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
154+
// CHECK: cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
155+
// CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
156+
// CHECK: %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
157+
// CHECK: "test.bar"(%[[cast2]])
158+
// CHECK: return
159+
func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
160+
"test.foo"(%arg0) : (f32) -> ()
161+
cf.br ^bb1(%arg0: f32)
162+
^bb1(%arg1: f32):
163+
cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
164+
^bb2(%arg2: f32):
165+
"test.bar"(%arg2) : (f32) -> ()
166+
return
167+
}

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
7171
)
7272
mlir_target_link_libraries(MLIRTestDialect PUBLIC
7373
MLIRControlFlowInterfaces
74+
MLIRControlFlowTransforms
7475
MLIRDataLayoutInterfaces
7576
MLIRDerivedAttributeOpInterface
7677
MLIRDestinationStyleOpInterface

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "TestTypes.h"
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/CommonFolders.h"
14+
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
1415
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1617
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
20422043
});
20432044
converter.addConversion([](IndexType type) { return type; });
20442045
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
2046+
if (type.isInteger(1)) {
2047+
// i1 is legal.
2048+
types.push_back(type);
2049+
}
20452050
if (type.isInteger(38)) {
20462051
// i38 is legal.
20472052
types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
21752180
converter);
21762181
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
21772182
converter, patterns, target);
2183+
mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
2184+
patterns, target);
21782185

21792186
ConversionConfig config;
21802187
config.allowPatternRollback = allowPatternRollback;

0 commit comments

Comments
 (0)