Skip to content

Commit c61c5d2

Browse files
lhutton1, Lallapallooza, shubham-arm, dflavin-arm
authored
[mlir][tosa] Add a pass to narrow i64 to i32 (#165581)
This pass aims to narrow i64 types on TOSA operations to i32. It can be useful for legalizations from various frameworks. It comes with the following options: - "aggressive-rewrite" - This option is typically able to narrow more values, but may impact numerical behaviour if not used carefully. - "convert-function-boundaries" - If enabled, parameters/ results to/from a function may be narrowed. Otherwise, casts are inserted to preserve the I/O of the function. Currently the non aggressive mode is very limited, targeting an argmax -> cast sequence that has been observed during legalization as well as some data layout operations that can always narrow. Support for more operations will be added in the future. Co-authored-by: Vitalii Shutov <[email protected]> Co-authored-by: Shubham <[email protected]> Co-authored-by: Declan Flavin <[email protected]> Signed-off-by: Luke Hutton <[email protected]> Co-authored-by: Vitalii Shutov <[email protected]> Co-authored-by: Shubham <[email protected]> Co-authored-by: Declan Flavin <[email protected]>
1 parent 3d5d32c commit c61c5d2

File tree

5 files changed

+577
-0
lines changed

5 files changed

+577
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,27 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
166166
];
167167
}
168168

169+
// Pass definition for narrowing 64-bit integer TOSA tensors to 32-bit.
def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
  let summary = "Narrow I64 TOSA operations to I32";
  let description = [{
    This pass narrows TOSA operations with 64-bit integer tensor types to
    32-bit integer tensor types. This can be useful for backends that do not
    support the EXT-INT64 extension of TOSA.
  }];

  // Adjacent string literals are concatenated verbatim, so each continued
  // fragment must end with its own separating space.
  let options = [
    Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
           "If enabled, all TOSA operations are rewritten, regardless of whether the narrowing "
           "is safe. This option may lead to data loss if not used carefully.">,
    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
           "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
           "be inserted at the I/O boundaries.">
  ];

  let dependentDialects = [
    "func::FuncDialect",
    "tosa::TosaDialect",
  ];
}
191+
169192
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
1212
TosaTypeConverters.cpp
1313
TosaProfileCompliance.cpp
1414
TosaValidation.cpp
15+
TosaNarrowI64ToI32.cpp
1516

1617
ADDITIONAL_HEADER_DIRS
1718
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass narrows TOSA operations with 64-bit integer tensor types to
10+
// 32-bit integer tensor types. This can be useful for backends that do not
11+
// support the EXT-INT64 extension of TOSA. The pass has two options:
12+
//
13+
// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
14+
// regardless of whether the narrowing is safe. This option may lead to
15+
// data loss if not used carefully.
16+
// - convert-function-boundaries - If enabled, the pass will convert function
17+
// I/O types as well. Otherwise casts will be inserted at the I/O
18+
// boundaries.
19+
//
20+
//===----------------------------------------------------------------------===//
21+
22+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
23+
24+
#include "mlir/Dialect/Func/IR/FuncOps.h"
25+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
26+
#include "mlir/IR/Verifier.h"
27+
#include "mlir/Pass/Pass.h"
28+
29+
namespace mlir {
30+
namespace tosa {
31+
#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
32+
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
33+
} // namespace tosa
34+
} // namespace mlir
35+
36+
using namespace mlir;
37+
using namespace mlir::tosa;
38+
39+
namespace {
40+
41+
/// Rebuilds `op` as a new operation of the same name whose result types,
/// attributes and region signatures have been run through `typeConverter`,
/// then replaces `op` with it.
///
/// \param op            The operation to rewrite.
/// \param operands      The already-remapped operands for the new operation.
/// \param rewriter      Rewriter used to create the new operation and to
///                      report match failures.
/// \param typeConverter Converter applied to result types, integer/type/dense
///                      attributes and region argument types.
/// \returns success() if the operation was replaced; failure() if any result
///          type, attribute or region could not be converted.
LogicalResult convertGenericOp(Operation *op, ValueRange operands,
                               ConversionPatternRewriter &rewriter,
                               const TypeConverter *typeConverter) {
  // Convert types of results
  SmallVector<Type, 4> newResults;
  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
    return failure();

  // Create a new operation state
  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                       newResults, {}, op->getSuccessors());

  for (const NamedAttribute &namedAttribute : op->getAttrs()) {
    const Attribute attribute = namedAttribute.getValue();

    // Convert integer attribute type
    if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
      const std::optional<Attribute> convertedAttribute =
          typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
      // The converter yields no value when no registered attribute conversion
      // applies; report that instead of dereferencing an empty optional.
      if (!convertedAttribute)
        return rewriter.notifyMatchFailure(
            op, "Failed to convert integer attribute.");
      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
      continue;
    }

    // Convert type attributes
    if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
      Type type = typeAttr.getValue();
      const std::optional<Attribute> convertedAttribute =
          typeConverter->convertTypeAttribute(type, attribute);
      if (!convertedAttribute)
        return rewriter.notifyMatchFailure(op,
                                           "Failed to convert type attribute.");
      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
      continue;
    }

    // Convert dense elements attributes (e.g. constant payloads)
    if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
      const Type type = denseElementsAttr.getType();
      const std::optional<Attribute> convertedAttribute =
          typeConverter->convertTypeAttribute(type, denseElementsAttr);
      if (!convertedAttribute)
        return rewriter.notifyMatchFailure(
            op, "Failed to convert dense elements attribute.");
      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
      continue;
    }

    // Any other attribute kind is forwarded unchanged.
    state.addAttribute(namedAttribute.getName(), attribute);
  }

  // Move each region into the new operation state and convert its block
  // argument types.
  for (Region &region : op->getRegions()) {
    Region *newRegion = state.addRegion();
    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
      return failure();
  }

  Operation *newOp = rewriter.create(state);
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}
100+
101+
// ===========================
102+
// Aggressive rewrite patterns
103+
// ===========================
104+
105+
class ConvertGenericOp : public ConversionPattern {
106+
public:
107+
ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
108+
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
109+
110+
LogicalResult
111+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112+
ConversionPatternRewriter &rewriter) const final {
113+
if (!isa<tosa::TosaOp>(op))
114+
return rewriter.notifyMatchFailure(
115+
op,
116+
"Support for operations other than TOSA has not been implemented.");
117+
118+
return convertGenericOp(op, operands, rewriter, typeConverter);
119+
}
120+
};
121+
122+
// ===============================
123+
// Bounds checked rewrite patterns
124+
// ===============================
125+
126+
class ConvertArgMaxOpWithBoundsChecking
127+
: public OpConversionPattern<tosa::ArgMaxOp> {
128+
using OpConversionPattern::OpConversionPattern;
129+
130+
LogicalResult
131+
matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
132+
ConversionPatternRewriter &rewriter) const final {
133+
// Output type can be narrowed based on the size of the axis dimension
134+
const int32_t axis = op.getAxis();
135+
const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
136+
if (!inputType || !inputType.isStaticDim(axis))
137+
return rewriter.notifyMatchFailure(
138+
op, "Requires a static axis dimension for bounds checking.");
139+
const int64_t axisDim = inputType.getDimSize(axis);
140+
if (axisDim >= std::numeric_limits<int32_t>::max())
141+
return rewriter.notifyMatchFailure(
142+
op, "Axis dimension is too large to narrow safely.");
143+
144+
const Type resultType = op.getOutput().getType();
145+
const Type newResultType = typeConverter->convertType(resultType);
146+
rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
147+
adaptor.getInput(), axis);
148+
return success();
149+
}
150+
};
151+
152+
class ConvertCastOpWithBoundsChecking
153+
: public OpConversionPattern<tosa::CastOp> {
154+
using OpConversionPattern::OpConversionPattern;
155+
156+
LogicalResult
157+
matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
158+
ConversionPatternRewriter &rewriter) const final {
159+
const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
160+
const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
161+
if (!inputType || !resultType)
162+
return failure();
163+
164+
const auto elementInputIntType =
165+
dyn_cast<IntegerType>(inputType.getElementType());
166+
const auto elementResultIntType =
167+
dyn_cast<IntegerType>(resultType.getElementType());
168+
if (elementInputIntType && elementResultIntType &&
169+
elementInputIntType.getWidth() > elementResultIntType.getWidth())
170+
return rewriter.notifyMatchFailure(
171+
op, "Narrowing cast may lead to data loss.");
172+
173+
rewriter.replaceOpWithNewOp<tosa::CastOp>(
174+
op, typeConverter->convertType(resultType), adaptor.getInput());
175+
return success();
176+
}
177+
};
178+
179+
template <typename OpTy>
180+
class ConvertTypedOp : public OpConversionPattern<OpTy> {
181+
using OpConversionPattern<OpTy>::OpConversionPattern;
182+
183+
LogicalResult
184+
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
185+
ConversionPatternRewriter &rewriter) const final {
186+
return convertGenericOp(op, adaptor.getOperands(), rewriter,
187+
this->getTypeConverter());
188+
}
189+
};
190+
191+
struct TosaNarrowI64ToI32
192+
: public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
193+
public:
194+
explicit TosaNarrowI64ToI32() = default;
195+
explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
196+
: TosaNarrowI64ToI32() {
197+
this->aggressiveRewrite = options.aggressiveRewrite;
198+
this->convertFunctionBoundaries = options.convertFunctionBoundaries;
199+
}
200+
201+
void runOnOperation() override {
202+
MLIRContext *context = &getContext();
203+
204+
TypeConverter typeConverter;
205+
typeConverter.addConversion([](Type type) -> Type { return type; });
206+
typeConverter.addConversion([](IntegerType type) -> Type {
207+
if (!type.isInteger(64))
208+
return type;
209+
return IntegerType::get(type.getContext(), 32);
210+
});
211+
typeConverter.addConversion(
212+
[&typeConverter](RankedTensorType type) -> Type {
213+
const Type elementType = type.getElementType();
214+
if (!elementType.isInteger(64))
215+
return type;
216+
return RankedTensorType::get(type.getShape(),
217+
typeConverter.convertType(elementType));
218+
});
219+
220+
const auto materializeCast = [](OpBuilder &builder, Type resultType,
221+
ValueRange inputs, Location loc) -> Value {
222+
if (inputs.size() != 1)
223+
return Value();
224+
return tosa::CastOp::create(builder, loc, resultType, inputs.front());
225+
};
226+
typeConverter.addSourceMaterialization(materializeCast);
227+
typeConverter.addTargetMaterialization(materializeCast);
228+
229+
typeConverter.addTypeAttributeConversion(
230+
[](IntegerType type, IntegerAttr attribute) -> Attribute {
231+
const APInt value = attribute.getValue().truncSSat(32);
232+
return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
233+
value);
234+
});
235+
typeConverter.addTypeAttributeConversion(
236+
[&typeConverter](ShapedType type,
237+
DenseIntElementsAttr attr) -> Attribute {
238+
const ShapedType newType =
239+
cast<ShapedType>(typeConverter.convertType(type));
240+
const auto oldElementType = cast<IntegerType>(type.getElementType());
241+
const auto newElementType =
242+
cast<IntegerType>(newType.getElementType());
243+
if (oldElementType.getWidth() == newElementType.getWidth())
244+
return attr;
245+
246+
DenseElementsAttr mapped =
247+
attr.mapValues(newElementType, [&](const APInt &v) {
248+
return v.truncSSat(newElementType.getWidth());
249+
});
250+
return mapped;
251+
});
252+
253+
ConversionTarget target(*context);
254+
target.addDynamicallyLegalDialect<tosa::TosaDialect>(
255+
[&typeConverter](Operation *op) {
256+
return typeConverter.isLegal(op->getResultTypes()) &&
257+
typeConverter.isLegal(op->getOperandTypes());
258+
});
259+
if (convertFunctionBoundaries) {
260+
target.addDynamicallyLegalOp<func::FuncOp>(
261+
[&typeConverter](func::FuncOp op) {
262+
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
263+
typeConverter.isLegal(&op.getBody());
264+
});
265+
target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
266+
const FunctionType funcType =
267+
op->getParentOfType<func::FuncOp>().getFunctionType();
268+
return llvm::equal(op.getOperandTypes(), funcType.getResults());
269+
});
270+
} else {
271+
target.addDynamicallyLegalOp<func::FuncOp>(
272+
[](func::FuncOp op) { return true; });
273+
target.addDynamicallyLegalOp<func::ReturnOp>(
274+
[](func::ReturnOp op) { return true; });
275+
}
276+
277+
RewritePatternSet patterns(context);
278+
if (convertFunctionBoundaries) {
279+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
280+
patterns, typeConverter);
281+
populateReturnOpTypeConversionPattern(patterns, typeConverter);
282+
}
283+
if (aggressiveRewrite) {
284+
patterns.add<ConvertGenericOp>(typeConverter, context);
285+
} else {
286+
// Tensor
287+
patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
288+
// Data layout
289+
patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
290+
patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
291+
patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
292+
patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
293+
patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
294+
patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
295+
patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
296+
patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
297+
// Type conversion
298+
patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
299+
// Controlflow
300+
patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
301+
patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
302+
}
303+
304+
if (failed(
305+
applyFullConversion(getOperation(), target, std::move(patterns))))
306+
signalPassFailure();
307+
}
308+
};
309+
310+
} // namespace

0 commit comments

Comments
 (0)