Skip to content

Commit 8deccc1

Browse files
committed
Support unary ops; test with LIT instead of C++
1 parent 9c201ee commit 8deccc1

File tree

6 files changed

+203
-76
lines changed

6 files changed

+203
-76
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
149149
return operands[1];
150150
}
151151

152-
auto getResultType = [](Attribute attr) -> Type {
152+
auto getAttrType = [](Attribute attr) -> Type {
153153
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
154154
return typed.getType();
155155
return {};
156156
};
157157

158-
Type lhsType = getResultType(operands[0]);
159-
Type rhsType = getResultType(operands[1]);
158+
Type lhsType = getAttrType(operands[0]);
159+
Type rhsType = getAttrType(operands[1]);
160160
if (!lhsType || !rhsType)
161161
return {};
162162
if (lhsType != rhsType)
@@ -202,16 +202,20 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
202202

203203
/// Performs constant folding `calculate` with element-wise behavior on the one
204204
/// attributes in `operands` and returns the result if possible.
205+
/// Uses `resultType` for the type of the returned attribute.
205206
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
206207
/// which will be directly propagated to result.
207-
template <class AttrElementT,
208+
template <class AttrElementT, //
208209
class ElementValueT = typename AttrElementT::ValueType,
209210
class PoisonAttr = ub::PoisonAttr,
211+
class ResultAttrElementT = AttrElementT,
212+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
210213
class CalculationT =
211-
function_ref<std::optional<ElementValueT>(ElementValueT)>>
214+
function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
212215
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
216+
Type resultType,
213217
CalculationT &&calculate) {
214-
if (!llvm::getSingleElement(operands))
218+
if (!resultType || !llvm::getSingleElement(operands))
215219
return {};
216220

217221
static_assert(
@@ -229,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
229233
auto res = calculate(op.getValue());
230234
if (!res)
231235
return {};
232-
return AttrElementT::get(op.getType(), *res);
236+
return ResultAttrElementT::get(resultType, *res);
233237
}
234238
if (isa<SplatElementsAttr>(operands[0])) {
235239
// Both operands are splats so we can avoid expanding the values out and
@@ -239,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
239243
auto elementResult = calculate(op.getSplatValue<ElementValueT>());
240244
if (!elementResult)
241245
return {};
242-
return DenseElementsAttr::get(op.getType(), *elementResult);
246+
return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
243247
} else if (isa<ElementsAttr>(operands[0])) {
244248
// Operands are ElementsAttr-derived; perform an element-wise fold by
245249
// expanding the values.
@@ -249,27 +253,89 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
249253
if (!maybeOpIt)
250254
return {};
251255
auto opIt = *maybeOpIt;
252-
SmallVector<ElementValueT> elementResults;
256+
SmallVector<ResultElementValueT> elementResults;
253257
elementResults.reserve(op.getNumElements());
254258
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
255259
auto elementResult = calculate(*opIt);
256260
if (!elementResult)
257261
return {};
258262
elementResults.push_back(*elementResult);
259263
}
260-
return DenseElementsAttr::get(op.getShapedType(), elementResults);
264+
return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
261265
}
262266
return {};
263267
}
264268

265-
template <class AttrElementT,
269+
/// Performs constant folding `calculate` with element-wise behavior on the one
270+
/// attributes in `operands` and returns the result if possible.
271+
/// Uses the operand element type for the element type of the returned
272+
/// attribute.
273+
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
274+
/// which will be directly propagated to result.
275+
template <class AttrElementT, //
276+
class ElementValueT = typename AttrElementT::ValueType,
277+
class PoisonAttr = ub::PoisonAttr,
278+
class ResultAttrElementT = AttrElementT,
279+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
280+
class CalculationT =
281+
function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
282+
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
283+
CalculationT &&calculate) {
284+
if (!llvm::getSingleElement(operands))
285+
return {};
286+
287+
static_assert(
288+
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289+
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290+
"void as template argument to opt-out from poison semantics.");
291+
if constexpr (!std::is_void_v<PoisonAttr>) {
292+
if (isa<PoisonAttr>(operands[0]))
293+
return operands[0];
294+
}
295+
296+
auto getAttrType = [](Attribute attr) -> Type {
297+
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298+
return typed.getType();
299+
return {};
300+
};
301+
302+
Type operandType = getAttrType(operands[0]);
303+
if (!operandType)
304+
return {};
305+
306+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307+
ResultAttrElementT, ResultElementValueT,
308+
CalculationT>(
309+
operands, operandType, std::forward<CalculationT>(calculate));
310+
}
311+
312+
template <class AttrElementT, //
313+
class ElementValueT = typename AttrElementT::ValueType,
314+
class PoisonAttr = ub::PoisonAttr,
315+
class ResultAttrElementT = AttrElementT,
316+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
317+
class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
318+
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands, Type resultType,
319+
CalculationT &&calculate) {
320+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321+
ResultAttrElementT>(
322+
operands, resultType,
323+
[&](ElementValueT a) -> std::optional<ResultElementValueT> {
324+
return calculate(a);
325+
});
326+
}
327+
328+
template <class AttrElementT, //
266329
class ElementValueT = typename AttrElementT::ValueType,
267330
class PoisonAttr = ub::PoisonAttr,
268-
class CalculationT = function_ref<ElementValueT(ElementValueT)>>
331+
class ResultAttrElementT = AttrElementT,
332+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
333+
class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
269334
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
270335
CalculationT &&calculate) {
271-
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
272-
operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
336+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337+
ResultAttrElementT>(
338+
operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
273339
return calculate(a);
274340
});
275341
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: @test_fold_unary_op_f32_to_si32(
4+
func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> {
5+
// CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32>
6+
// CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32>
7+
%operand = arith.constant dense<5.1> : tensor<4x2xf32>
8+
%sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32>
9+
return %sign : tensor<4x2xsi32>
10+
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: @test_fold_binary_op_f32_to_i1(
15+
func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> {
16+
// CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense<false> : tensor<8xi1>
17+
// CHECK-NEXT: return %[[FALSE]] : tensor<8xi1>
18+
%lhs = arith.constant dense<5.1> : tensor<8xf32>
19+
%rhs = arith.constant dense<4.2> : tensor<8xf32>
20+
%less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1>
21+
return %less_than : tensor<8xi1>
22+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,26 @@ def OpP : TEST_Op<"op_p"> {
11691169
let results = (outs I32);
11701170
}
11711171

1172+
// Test constant-folding a pattern that maps (F32, F32) -> I32.
1173+
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
1174+
let arguments = (ins RankedTensorOf<[F32]>:$operand);
1175+
let results = (outs RankedTensorOf<[SI32]>:$result);
1176+
1177+
let assemblyFormat = [{
1178+
$operand attr-dict `:` functional-type(operands, results)
1179+
}];
1180+
}
1181+
1182+
// Test constant-folding a pattern that maps (F32, F32) -> I32.
1183+
def LessThanOp : TEST_Op<"less_than", [SameOperandsAndResultShape]> {
1184+
let arguments = (ins RankedTensorOf<[F32]>:$lhs, RankedTensorOf<[F32]>:$rhs);
1185+
let results = (outs RankedTensorOf<[I1]>:$result);
1186+
1187+
let assemblyFormat = [{
1188+
$lhs `,` $rhs attr-dict `:` functional-type(operands, results)
1189+
}];
1190+
}
1191+
11721192
// Test same operand name enforces equality condition check.
11731193
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
11741194

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "TestOps.h"
1111
#include "TestTypes.h"
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/CommonFolders.h"
1314
#include "mlir/Dialect/Func/IR/FuncOps.h"
1415
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1516
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -202,6 +203,66 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
202203
}
203204
};
204205

206+
struct FoldSignOpF32ToSI32 : public OpRewritePattern<test::SignOp> {
207+
using OpRewritePattern<test::SignOp>::OpRewritePattern;
208+
209+
LogicalResult matchAndRewrite(test::SignOp op,
210+
PatternRewriter &rewriter) const override {
211+
if (op->getNumOperands() != 1 || op->getNumResults() != 1)
212+
return failure();
213+
214+
TypedAttr operandAttr;
215+
matchPattern(op->getOperand(0), m_Constant(&operandAttr));
216+
if (!operandAttr)
217+
return failure();
218+
219+
TypedAttr res = cast_or_null<TypedAttr>(
220+
constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
221+
operandAttr, op.getType(), [](APFloat operand) -> APSInt {
222+
static const APFloat zero(0.0f);
223+
int operandSign = 0;
224+
if (operand != zero)
225+
operandSign = (operand < zero) ? -1 : +1;
226+
return APSInt(APInt(32, operandSign), false);
227+
}));
228+
if (!res)
229+
return failure();
230+
231+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, res);
232+
return success();
233+
}
234+
};
235+
236+
struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
237+
using OpRewritePattern<test::LessThanOp>::OpRewritePattern;
238+
239+
LogicalResult matchAndRewrite(test::LessThanOp op,
240+
PatternRewriter &rewriter) const override {
241+
if (op->getNumOperands() != 2 || op->getNumResults() != 1)
242+
return failure();
243+
244+
TypedAttr lhsAttr;
245+
TypedAttr rhsAttr;
246+
matchPattern(op->getOperand(0), m_Constant(&lhsAttr));
247+
matchPattern(op->getOperand(1), m_Constant(&rhsAttr));
248+
249+
if (!lhsAttr || !rhsAttr)
250+
return failure();
251+
252+
Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
253+
TypedAttr res = cast_or_null<TypedAttr>(
254+
constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
255+
operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
256+
return APInt(1, lhs < rhs);
257+
}));
258+
if (!res)
259+
return failure();
260+
261+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, res);
262+
return success();
263+
}
264+
};
265+
205266
/// This pattern moves "test.move_before_parent_op" before the parent op.
206267
struct MoveBeforeParentOp : public RewritePattern {
207268
MoveBeforeParentOp(MLIRContext *context)
@@ -2181,6 +2242,24 @@ struct TestSelectiveReplacementPatternDriver
21812242
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
21822243
}
21832244
};
2245+
2246+
struct TestFoldTypeConvertingOp
2247+
: public PassWrapper<TestFoldTypeConvertingOp, OperationPass<>> {
2248+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldTypeConvertingOp)
2249+
2250+
StringRef getArgument() const final { return "test-fold-type-converting-op"; }
2251+
StringRef getDescription() const final {
2252+
return "Test helper functions for folding ops whose input and output types "
2253+
"differ, e.g. float comparisons of the form `(f32, f32) -> i1`.";
2254+
}
2255+
void runOnOperation() override {
2256+
MLIRContext *context = &getContext();
2257+
mlir::RewritePatternSet patterns(context);
2258+
patterns.add<FoldSignOpF32ToSI32, FoldLessThanOpF32ToI1>(context);
2259+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
2260+
signalPassFailure();
2261+
}
2262+
};
21842263
} // namespace
21852264

21862265
//===----------------------------------------------------------------------===//
@@ -2211,6 +2290,8 @@ void registerPatternsTestPass() {
22112290

22122291
PassRegistration<TestMergeBlocksPatternDriver>();
22132292
PassRegistration<TestSelectiveReplacementPatternDriver>();
2293+
2294+
PassRegistration<TestFoldTypeConvertingOp>();
22142295
}
22152296
} // namespace test
22162297
} // namespace mlir

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_unittest(MLIRDialectTests
22
BroadcastShapeTest.cpp
3-
CommonFoldersTest.cpp
43
)
54
mlir_target_link_libraries(MLIRDialectTests
65
PRIVATE

mlir/unittests/Dialect/CommonFoldersTest.cpp

Lines changed: 0 additions & 61 deletions
This file was deleted.

0 commit comments

Comments
 (0)