Skip to content

Commit ca84e9e

Browse files
[mlir][CF] Add structural type conversion patterns (llvm#165629)
Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns. This commit adds missing functionality and is in preparation of llvm#165180, which changes the way blocks are converted. (Only entry blocks are converted.)
1 parent 21bcd00 commit ca84e9e

File tree

6 files changed

+248
-0
lines changed

6 files changed

+248
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
10+
#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
11+
12+
#include "mlir/IR/PatternMatch.h"
13+
14+
namespace mlir {
15+
16+
class ConversionTarget;
17+
class TypeConverter;
18+
19+
namespace cf {
20+
21+
/// Populates patterns for CF structural type conversions and sets up the
22+
/// provided ConversionTarget with the appropriate legality configuration for
23+
/// the ops to get converted properly.
24+
///
25+
/// A "structural" type conversion is one where the underlying ops are
26+
/// completely agnostic to the actual types involved and simply need to update
27+
/// their types. An example of this is cf.br -- the cf.br op needs to update
28+
/// its types accordingly to the TypeConverter, but otherwise does not care
29+
/// what type conversions are happening.
30+
void populateCFStructuralTypeConversionsAndLegality(
31+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
32+
ConversionTarget &target, PatternBenefit benefit = 1);
33+
34+
/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
35+
/// populate the conversion target.
36+
void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
37+
RewritePatternSet &patterns,
38+
PatternBenefit benefit = 1);
39+
40+
/// Updates the ConversionTarget with dynamic legality of CF operations based
41+
/// on the provided type converter.
42+
void populateCFStructuralTypeConversionTarget(
43+
const TypeConverter &typeConverter, ConversionTarget &target);
44+
45+
} // namespace cf
46+
} // namespace mlir
47+
48+
#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H

mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRControlFlowTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
4+
StructuralTypeConversions.cpp
45

56
ADDITIONAL_HEADER_DIRS
67
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert MLIR standard and builtin dialects
10+
// into the LLVM IR dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
15+
16+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
25+
/// Helper function for converting branch ops. This function converts the
26+
/// signature of the given block. If the new block signature is different from
27+
/// `expectedTypes`, returns "failure".
28+
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
29+
const TypeConverter *converter,
30+
Operation *branchOp, Block *block,
31+
TypeRange expectedTypes) {
32+
assert(converter && "expected non-null type converter");
33+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
34+
35+
// There is nothing to do if the types already match.
36+
if (block->getArgumentTypes() == expectedTypes)
37+
return block;
38+
39+
// Compute the new block argument types and convert the block.
40+
std::optional<TypeConverter::SignatureConversion> conversion =
41+
converter->convertBlockSignature(block);
42+
if (!conversion)
43+
return rewriter.notifyMatchFailure(branchOp,
44+
"could not compute block signature");
45+
if (expectedTypes != conversion->getConvertedTypes())
46+
return rewriter.notifyMatchFailure(
47+
branchOp,
48+
"mismatch between adaptor operand types and computed block signature");
49+
return rewriter.applySignatureConversion(block, *conversion, converter);
50+
}
51+
52+
/// Flatten the given value ranges into a single vector of values.
53+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
54+
SmallVector<Value> result;
55+
for (const ValueRange &vals : values)
56+
llvm::append_range(result, vals);
57+
return result;
58+
}
59+
60+
/// Convert the destination block signature (if necessary) and change the
61+
/// operands of the branch op.
62+
struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
63+
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
64+
65+
LogicalResult
66+
matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
67+
ConversionPatternRewriter &rewriter) const override {
68+
SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
69+
FailureOr<Block *> convertedBlock =
70+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
71+
TypeRange(ValueRange(flattenedAdaptor)));
72+
if (failed(convertedBlock))
73+
return failure();
74+
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
75+
*convertedBlock);
76+
return success();
77+
}
78+
};
79+
80+
/// Convert the destination block signatures (if necessary) and change the
81+
/// operands of the branch op.
82+
struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
83+
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
84+
85+
LogicalResult
86+
matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
87+
ConversionPatternRewriter &rewriter) const override {
88+
SmallVector<Value> flattenedAdaptorTrue =
89+
flattenValues(adaptor.getTrueDestOperands());
90+
SmallVector<Value> flattenedAdaptorFalse =
91+
flattenValues(adaptor.getFalseDestOperands());
92+
if (!llvm::hasSingleElement(adaptor.getCondition()))
93+
return rewriter.notifyMatchFailure(op,
94+
"expected single element condition");
95+
FailureOr<Block *> convertedTrueBlock =
96+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
97+
TypeRange(ValueRange(flattenedAdaptorTrue)));
98+
if (failed(convertedTrueBlock))
99+
return failure();
100+
FailureOr<Block *> convertedFalseBlock =
101+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
102+
TypeRange(ValueRange(flattenedAdaptorFalse)));
103+
if (failed(convertedFalseBlock))
104+
return failure();
105+
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
106+
op, llvm::getSingleElement(adaptor.getCondition()),
107+
flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
108+
*convertedTrueBlock, *convertedFalseBlock);
109+
return success();
110+
}
111+
};
112+
113+
/// Convert the destination block signatures (if necessary) and change the
114+
/// operands of the switch op.
115+
struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
116+
using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
117+
118+
LogicalResult
119+
matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
120+
ConversionPatternRewriter &rewriter) const override {
121+
// Get or convert default block.
122+
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
123+
rewriter, getTypeConverter(), op, op.getDefaultDestination(),
124+
TypeRange(adaptor.getDefaultOperands()));
125+
if (failed(convertedDefaultBlock))
126+
return failure();
127+
128+
// Get or convert all case blocks.
129+
SmallVector<Block *> caseDestinations;
130+
SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
131+
for (auto it : llvm::enumerate(op.getCaseDestinations())) {
132+
Block *b = it.value();
133+
FailureOr<Block *> convertedBlock =
134+
getConvertedBlock(rewriter, getTypeConverter(), op, b,
135+
TypeRange(caseOperands[it.index()]));
136+
if (failed(convertedBlock))
137+
return failure();
138+
caseDestinations.push_back(*convertedBlock);
139+
}
140+
141+
rewriter.replaceOpWithNewOp<cf::SwitchOp>(
142+
op, adaptor.getFlag(), *convertedDefaultBlock,
143+
adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
144+
caseDestinations, caseOperands);
145+
return success();
146+
}
147+
};
148+
149+
} // namespace
150+
151+
void mlir::cf::populateCFStructuralTypeConversions(
152+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
153+
PatternBenefit benefit) {
154+
patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
155+
typeConverter, patterns.getContext(), benefit);
156+
}
157+
158+
void mlir::cf::populateCFStructuralTypeConversionTarget(
159+
const TypeConverter &typeConverter, ConversionTarget &target) {
160+
target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
161+
[&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
162+
}
163+
164+
void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
165+
const TypeConverter &typeConverter, RewritePatternSet &patterns,
166+
ConversionTarget &target, PatternBenefit benefit) {
167+
populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
168+
populateCFStructuralTypeConversionTarget(typeConverter, target);
169+
}

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
143143
return
144144
}
145145

146+
// -----
147+
148+
// CHECK-LABEL: func @test_unstructured_cf_conversion(
149+
// CHECK-SAME: %[[arg0:.*]]: f64, %[[c:.*]]: i1)
150+
// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
151+
// CHECK: "test.foo"(%[[cast1]])
152+
// CHECK: cf.br ^[[bb1:.*]](%[[arg0]] : f64)
153+
// CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
154+
// CHECK: cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
155+
// CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
156+
// CHECK: %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
157+
// CHECK: "test.bar"(%[[cast2]])
158+
// CHECK: return
159+
func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
160+
"test.foo"(%arg0) : (f32) -> ()
161+
cf.br ^bb1(%arg0: f32)
162+
^bb1(%arg1: f32):
163+
cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
164+
^bb2(%arg2: f32):
165+
"test.bar"(%arg2) : (f32) -> ()
166+
return
167+
}

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
7171
)
7272
mlir_target_link_libraries(MLIRTestDialect PUBLIC
7373
MLIRControlFlowInterfaces
74+
MLIRControlFlowTransforms
7475
MLIRDataLayoutInterfaces
7576
MLIRDerivedAttributeOpInterface
7677
MLIRDestinationStyleOpInterface

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "TestTypes.h"
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/CommonFolders.h"
14+
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
1415
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1617
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
20422043
});
20432044
converter.addConversion([](IndexType type) { return type; });
20442045
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
2046+
if (type.isInteger(1)) {
2047+
// i1 is legal.
2048+
types.push_back(type);
2049+
}
20452050
if (type.isInteger(38)) {
20462051
// i38 is legal.
20472052
types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
21752180
converter);
21762181
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
21772182
converter, patterns, target);
2183+
mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
2184+
patterns, target);
21782185

21792186
ConversionConfig config;
21802187
config.allowPatternRollback = allowPatternRollback;

0 commit comments

Comments
 (0)