Skip to content

Commit de2bac3

Browse files
authored
[MLIR] Allow constFoldBinaryOp to fold (T1, T1) -> T2 (#151410)
The `constFoldBinaryOp` helper function had limited support for different input and output types, but the static type of the underlying value (e.g. `APInt`) had to match between the inputs and the output. This worked fine for int comparisons of the form `(intN, intN) -> int1`, as the static type signature was `(APInt, APInt) -> APInt`. However, float comparisons map `(floatN, floatN) -> int1`, with a static type signature of `(APFloat, APFloat) -> APInt`. This use case wasn't supported by `constFoldBinaryOp`. `constFoldBinaryOp` now accepts an optional template argument overriding the return type in case it differs from the input type. If the new template argument isn't provided, the default behavior is unchanged (i.e. the return type will be assumed to match the input type). `constFoldUnaryOp` received similar changes in order to support folding non-cast ops of the form `(T1) -> T2` (e.g. a `sign` op mapping `(floatN) -> sint32`).
1 parent 6f272d1 commit de2bac3

File tree

4 files changed

+237
-33
lines changed

4 files changed

+237
-33
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
1616
#define MLIR_DIALECT_COMMONFOLDERS_H
1717

18+
#include "mlir/IR/Attributes.h"
19+
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1820
#include "mlir/IR/BuiltinAttributes.h"
19-
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/BuiltinTypeInterfaces.h"
22+
#include "mlir/IR/Types.h"
2023
#include "llvm/ADT/ArrayRef.h"
2124
#include "llvm/ADT/STLExtras.h"
25+
26+
#include <cassert>
27+
#include <cstddef>
2228
#include <optional>
2329

2430
namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
3036
/// Uses `resultType` for the type of the returned attribute.
3137
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
3238
/// which will be directly propagated to result.
33-
template <class AttrElementT,
39+
template <class AttrElementT, //
3440
class ElementValueT = typename AttrElementT::ValueType,
3541
class PoisonAttr = ub::PoisonAttr,
42+
class ResultAttrElementT = AttrElementT,
43+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
3644
class CalculationT = function_ref<
37-
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
45+
std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
3846
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
3947
Type resultType,
4048
CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
6573
if (!calRes)
6674
return {};
6775

68-
return AttrElementT::get(resultType, *calRes);
76+
return ResultAttrElementT::get(resultType, *calRes);
6977
}
7078

7179
if (isa<SplatElementsAttr>(operands[0]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
99107
return {};
100108
auto lhsIt = *maybeLhsIt;
101109
auto rhsIt = *maybeRhsIt;
102-
SmallVector<ElementValueT, 4> elementResults;
110+
SmallVector<ResultElementValueT, 4> elementResults;
103111
elementResults.reserve(lhs.getNumElements());
104112
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105113
auto elementResult = calculate(*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
119127
/// attribute.
120128
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
121129
/// which will be directly propagated to result.
122-
template <class AttrElementT,
130+
template <class AttrElementT, //
123131
class ElementValueT = typename AttrElementT::ValueType,
124132
class PoisonAttr = ub::PoisonAttr,
133+
class ResultAttrElementT = AttrElementT,
134+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
125135
class CalculationT = function_ref<
126-
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
136+
std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
127137
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
128138
CalculationT &&calculate) {
129139
assert(operands.size() == 2 && "binary op takes two operands");
@@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
139149
return operands[1];
140150
}
141151

142-
auto getResultType = [](Attribute attr) -> Type {
152+
auto getAttrType = [](Attribute attr) -> Type {
143153
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144154
return typed.getType();
145155
return {};
146156
};
147157

148-
Type lhsType = getResultType(operands[0]);
149-
Type rhsType = getResultType(operands[1]);
158+
Type lhsType = getAttrType(operands[0]);
159+
Type rhsType = getAttrType(operands[1]);
150160
if (!lhsType || !rhsType)
151161
return {};
152162
if (lhsType != rhsType)
153163
return {};
154164

155165
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166+
ResultAttrElementT, ResultElementValueT,
156167
CalculationT>(
157168
operands, lhsType, std::forward<CalculationT>(calculate));
158169
}
159170

160171
template <class AttrElementT,
161172
class ElementValueT = typename AttrElementT::ValueType,
162-
class PoisonAttr = void,
173+
class PoisonAttr = void, //
174+
class ResultAttrElementT = AttrElementT,
175+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
163176
class CalculationT =
164-
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
177+
function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
165178
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
166179
CalculationT &&calculate) {
167-
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
180+
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181+
ResultAttrElementT>(
168182
operands, resultType,
169-
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170-
return calculate(a, b);
171-
});
183+
[&](ElementValueT a, ElementValueT b)
184+
-> std::optional<ResultElementValueT> { return calculate(a, b); });
172185
}
173186

174-
template <class AttrElementT,
187+
template <class AttrElementT, //
175188
class ElementValueT = typename AttrElementT::ValueType,
176189
class PoisonAttr = ub::PoisonAttr,
190+
class ResultAttrElementT = AttrElementT,
191+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
177192
class CalculationT =
178-
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
193+
function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
179194
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
180195
CalculationT &&calculate) {
181-
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
196+
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197+
ResultAttrElementT>(
182198
operands,
183-
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184-
return calculate(a, b);
185-
});
199+
[&](ElementValueT a, ElementValueT b)
200+
-> std::optional<ResultElementValueT> { return calculate(a, b); });
186201
}
187202

188203
/// Performs constant folding `calculate` with element-wise behavior on the one
189204
/// attributes in `operands` and returns the result if possible.
205+
/// Uses `resultType` for the type of the returned attribute.
190206
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
191207
/// which will be directly propagated to result.
192-
template <class AttrElementT,
208+
template <class AttrElementT, //
193209
class ElementValueT = typename AttrElementT::ValueType,
194210
class PoisonAttr = ub::PoisonAttr,
211+
class ResultAttrElementT = AttrElementT,
212+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
195213
class CalculationT =
196-
function_ref<std::optional<ElementValueT>(ElementValueT)>>
214+
function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
197215
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
216+
Type resultType,
198217
CalculationT &&calculate) {
199-
if (!llvm::getSingleElement(operands))
218+
if (!resultType || !llvm::getSingleElement(operands))
200219
return {};
201220

202221
static_assert(
@@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
214233
auto res = calculate(op.getValue());
215234
if (!res)
216235
return {};
217-
return AttrElementT::get(op.getType(), *res);
236+
return ResultAttrElementT::get(resultType, *res);
218237
}
219238
if (isa<SplatElementsAttr>(operands[0])) {
220239
// Both operands are splats so we can avoid expanding the values out and
@@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
224243
auto elementResult = calculate(op.getSplatValue<ElementValueT>());
225244
if (!elementResult)
226245
return {};
227-
return DenseElementsAttr::get(op.getType(), *elementResult);
246+
return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
228247
} else if (isa<ElementsAttr>(operands[0])) {
229248
// Operands are ElementsAttr-derived; perform an element-wise fold by
230249
// expanding the values.
@@ -234,27 +253,89 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
234253
if (!maybeOpIt)
235254
return {};
236255
auto opIt = *maybeOpIt;
237-
SmallVector<ElementValueT> elementResults;
256+
SmallVector<ResultElementValueT> elementResults;
238257
elementResults.reserve(op.getNumElements());
239258
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
240259
auto elementResult = calculate(*opIt);
241260
if (!elementResult)
242261
return {};
243262
elementResults.push_back(*elementResult);
244263
}
245-
return DenseElementsAttr::get(op.getShapedType(), elementResults);
264+
return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
246265
}
247266
return {};
248267
}
249268

250-
template <class AttrElementT,
269+
/// Performs constant folding `calculate` with element-wise behavior on the one
270+
/// attributes in `operands` and returns the result if possible.
271+
/// Uses the operand element type for the element type of the returned
272+
/// attribute.
273+
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
274+
/// which will be directly propagated to result.
275+
template <class AttrElementT, //
276+
class ElementValueT = typename AttrElementT::ValueType,
277+
class PoisonAttr = ub::PoisonAttr,
278+
class ResultAttrElementT = AttrElementT,
279+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
280+
class CalculationT =
281+
function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
282+
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
283+
CalculationT &&calculate) {
284+
if (!llvm::getSingleElement(operands))
285+
return {};
286+
287+
static_assert(
288+
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289+
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290+
"void as template argument to opt-out from poison semantics.");
291+
if constexpr (!std::is_void_v<PoisonAttr>) {
292+
if (isa<PoisonAttr>(operands[0]))
293+
return operands[0];
294+
}
295+
296+
auto getAttrType = [](Attribute attr) -> Type {
297+
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298+
return typed.getType();
299+
return {};
300+
};
301+
302+
Type operandType = getAttrType(operands[0]);
303+
if (!operandType)
304+
return {};
305+
306+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307+
ResultAttrElementT, ResultElementValueT,
308+
CalculationT>(
309+
operands, operandType, std::forward<CalculationT>(calculate));
310+
}
311+
312+
template <class AttrElementT, //
313+
class ElementValueT = typename AttrElementT::ValueType,
314+
class PoisonAttr = ub::PoisonAttr,
315+
class ResultAttrElementT = AttrElementT,
316+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
317+
class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
318+
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands, Type resultType,
319+
CalculationT &&calculate) {
320+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321+
ResultAttrElementT>(
322+
operands, resultType,
323+
[&](ElementValueT a) -> std::optional<ResultElementValueT> {
324+
return calculate(a);
325+
});
326+
}
327+
328+
template <class AttrElementT, //
251329
class ElementValueT = typename AttrElementT::ValueType,
252330
class PoisonAttr = ub::PoisonAttr,
253-
class CalculationT = function_ref<ElementValueT(ElementValueT)>>
331+
class ResultAttrElementT = AttrElementT,
332+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
333+
class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
254334
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
255335
CalculationT &&calculate) {
256-
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
257-
operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
336+
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337+
ResultAttrElementT>(
338+
operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
258339
return calculate(a);
259340
});
260341
}

mlir/test/Dialect/common_folders.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: @test_fold_unary_op_f32_to_si32(
4+
func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> {
5+
// CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32>
6+
// CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32>
7+
%operand = arith.constant dense<5.1> : tensor<4x2xf32>
8+
%sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32>
9+
return %sign : tensor<4x2xsi32>
10+
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: @test_fold_binary_op_f32_to_i1(
15+
func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> {
16+
// CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense<false> : tensor<8xi1>
17+
// CHECK-NEXT: return %[[FALSE]] : tensor<8xi1>
18+
%lhs = arith.constant dense<5.1> : tensor<8xf32>
19+
%rhs = arith.constant dense<4.2> : tensor<8xf32>
20+
%less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1>
21+
return %less_than : tensor<8xi1>
22+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,26 @@ def OpP : TEST_Op<"op_p"> {
11691169
let results = (outs I32);
11701170
}
11711171

1172+
// Test constant-folding a pattern that maps `(F32) -> SI32`.
1173+
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
1174+
let arguments = (ins RankedTensorOf<[F32]>:$operand);
1175+
let results = (outs RankedTensorOf<[SI32]>:$result);
1176+
1177+
let assemblyFormat = [{
1178+
$operand attr-dict `:` functional-type(operands, results)
1179+
}];
1180+
}
1181+
1182+
// Test constant-folding a pattern that maps `(F32, F32) -> I1`.
1183+
def LessThanOp : TEST_Op<"less_than", [SameOperandsAndResultShape]> {
1184+
let arguments = (ins RankedTensorOf<[F32]>:$lhs, RankedTensorOf<[F32]>:$rhs);
1185+
let results = (outs RankedTensorOf<[I1]>:$result);
1186+
1187+
let assemblyFormat = [{
1188+
$lhs `,` $rhs attr-dict `:` functional-type(operands, results)
1189+
}];
1190+
}
1191+
11721192
// Test same operand name enforces equality condition check.
11731193
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
11741194

0 commit comments

Comments
 (0)