Skip to content

Commit 4b08b0b

Browse files
committed
Allow constFoldBinaryOp to fold (T1, T1) -> T2
The `constFoldBinaryOp` helper function had limited support for different input and output types, but the static type of the underlying value (e.g. `APInt`) had to match between the inputs and the output. This worked fine for int comparisons of the form `(intN, intN) -> int1`, as the static type signature was `(APInt, APInt) -> APInt`. However, float comparisons map `(floatN, floatN) -> int1`, with a static type signature of `(APFloat, APFloat) -> APInt`. This use case wasn't supported by `constFoldBinaryOp`. `constFoldBinaryOp` now accepts an optional template argument overriding the return type in case it differs from the input type. If the new template argument isn't provided, the default behavior is unchanged (i.e. the return type will be assumed to match the input type).
1 parent 2e3fd54 commit 4b08b0b

File tree

3 files changed

+96
-19
lines changed

3 files changed

+96
-19
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
1616
#define MLIR_DIALECT_COMMONFOLDERS_H
1717

18+
#include "mlir/IR/Attributes.h"
19+
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1820
#include "mlir/IR/BuiltinAttributes.h"
19-
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/BuiltinTypeInterfaces.h"
22+
#include "mlir/IR/Types.h"
2023
#include "llvm/ADT/ArrayRef.h"
2124
#include "llvm/ADT/STLExtras.h"
25+
26+
#include <cassert>
27+
#include <cstddef>
2228
#include <optional>
2329

2430
namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
3036
/// Uses `resultType` for the type of the returned attribute.
3137
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
3238
/// which will be directly propagated to result.
33-
template <class AttrElementT,
39+
template <class AttrElementT, //
3440
class ElementValueT = typename AttrElementT::ValueType,
3541
class PoisonAttr = ub::PoisonAttr,
42+
class ResultAttrElementT = AttrElementT,
43+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
3644
class CalculationT = function_ref<
37-
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
45+
std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
3846
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
3947
Type resultType,
4048
CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
6573
if (!calRes)
6674
return {};
6775

68-
return AttrElementT::get(resultType, *calRes);
76+
return ResultAttrElementT::get(resultType, *calRes);
6977
}
7078

7179
if (isa<SplatElementsAttr>(operands[0]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
99107
return {};
100108
auto lhsIt = *maybeLhsIt;
101109
auto rhsIt = *maybeRhsIt;
102-
SmallVector<ElementValueT, 4> elementResults;
110+
SmallVector<ResultElementValueT, 4> elementResults;
103111
elementResults.reserve(lhs.getNumElements());
104112
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105113
auto elementResult = calculate(*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
119127
/// attribute.
120128
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
121129
/// which will be directly propagated to result.
122-
template <class AttrElementT,
130+
template <class AttrElementT, //
123131
class ElementValueT = typename AttrElementT::ValueType,
124132
class PoisonAttr = ub::PoisonAttr,
133+
class ResultAttrElementT = AttrElementT,
134+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
125135
class CalculationT = function_ref<
126-
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
136+
std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
127137
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
128138
CalculationT &&calculate) {
129139
assert(operands.size() == 2 && "binary op takes two operands");
@@ -153,36 +163,41 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
153163
return {};
154164

155165
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166+
ResultAttrElementT, ResultElementValueT,
156167
CalculationT>(
157168
operands, lhsType, std::forward<CalculationT>(calculate));
158169
}
159170

160171
template <class AttrElementT,
161172
class ElementValueT = typename AttrElementT::ValueType,
162-
class PoisonAttr = void,
173+
class PoisonAttr = void, //
174+
class ResultAttrElementT = AttrElementT,
175+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
163176
class CalculationT =
164-
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
177+
function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
165178
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
166179
CalculationT &&calculate) {
167-
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
180+
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181+
ResultAttrElementT>(
168182
operands, resultType,
169-
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170-
return calculate(a, b);
171-
});
183+
[&](ElementValueT a, ElementValueT b)
184+
-> std::optional<ResultElementValueT> { return calculate(a, b); });
172185
}
173186

174-
template <class AttrElementT,
187+
template <class AttrElementT, //
175188
class ElementValueT = typename AttrElementT::ValueType,
176189
class PoisonAttr = ub::PoisonAttr,
190+
class ResultAttrElementT = AttrElementT,
191+
class ResultElementValueT = typename ResultAttrElementT::ValueType,
177192
class CalculationT =
178-
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
193+
function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
179194
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
180195
CalculationT &&calculate) {
181-
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
196+
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197+
ResultAttrElementT>(
182198
operands,
183-
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184-
return calculate(a, b);
185-
});
199+
[&](ElementValueT a, ElementValueT b)
200+
-> std::optional<ResultElementValueT> { return calculate(a, b); });
186201
}
187202

188203
/// Performs constant folding `calculate` with element-wise behavior on the one

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRDialectTests
22
BroadcastShapeTest.cpp
3+
CommonFoldersTest.cpp
34
)
45
mlir_target_link_libraries(MLIRDialectTests
56
PRIVATE
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- CommonFoldersTest.cpp - tests for folder-pattern helper templates --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/CommonFolders.h"
10+
#include "mlir/IR/BuiltinAttributes.h"
11+
#include "mlir/IR/BuiltinTypeInterfaces.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/MLIRContext.h"
14+
#include "llvm/ADT/APFloat.h"
15+
#include "llvm/ADT/APInt.h"
16+
#include "llvm/Support/Casting.h"
17+
#include "gmock/gmock.h"
18+
#include "gtest/gtest.h"
19+
20+
namespace mlir {
21+
namespace {
22+
23+
using ::llvm::APFloat;
24+
using ::llvm::APInt;
25+
using ::mlir::constFoldBinaryOp;
26+
using ::mlir::DenseElementsAttr;
27+
using ::mlir::Float32Type;
28+
using ::mlir::FloatAttr;
29+
using ::mlir::IntegerAttr;
30+
using ::mlir::IntegerType;
31+
using ::mlir::MLIRContext;
32+
using ::mlir::RankedTensorType;
33+
using ::testing::ElementsAre;
34+
35+
APInt floatLessThan(APFloat lhs, APFloat rhs) { return APInt(1, lhs < rhs); }
36+
37+
TEST(CommonFoldersTest, FoldFloatComparisonToBoolean) {
38+
MLIRContext context;
39+
auto vector4xf32 = RankedTensorType::get({4}, Float32Type::get(&context));
40+
41+
auto lhs = DenseElementsAttr::get(vector4xf32, {-12.9f, 0.0f, 42.5f, -0.01f});
42+
auto rhs = DenseElementsAttr::get(vector4xf32, {0.0f, 0.0f, 0.0f, 0.0f});
43+
44+
auto result = llvm::dyn_cast<DenseElementsAttr>(
45+
constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
46+
{lhs, rhs}, RankedTensorType::get({4}, IntegerType::get(&context, 1)),
47+
floatLessThan));
48+
ASSERT_TRUE(result);
49+
50+
auto resultElementType = result.getElementType();
51+
EXPECT_TRUE(resultElementType.isInteger(1));
52+
53+
const APInt i1True = APInt(1, true);
54+
const APInt i1False = APInt(1, false);
55+
56+
EXPECT_THAT(result.getValues<APInt>(),
57+
ElementsAre(i1True, i1False, i1False, i1True));
58+
}
59+
60+
} // namespace
61+
} // namespace mlir

0 commit comments

Comments
 (0)