Skip to content

Commit 6f65e77

Browse files
committed
♻️ Adapt valueToDouble to handle implicit conversion from integers
1 parent 6aefb99 commit 6f65e77

File tree

2 files changed

+99
-54
lines changed

2 files changed

+99
-54
lines changed

mlir/include/mlir/Dialect/Utils/Utils.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,31 @@ variantToValue(OpBuilder& builder, const OperationState& state,
4242
}
4343

4444
/**
45-
* @brief Try to convert a float mlir::Value to a standard C++ double
45+
* @brief Try to convert a mlir::Value to a standard C++ double
4646
*
4747
* @details
4848
* Resolving the mlir::Value will only work if it is a static value, so a value
49-
* of type float and defined via a "arith.constant" operation.
49+
* defined via a "arith.constant" operation. It must also be of type
50+
* float or integer.
5051
*/
5152
[[nodiscard]] inline std::optional<double> valueToDouble(Value value) {
5253
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
5354
if (!constantOp) {
5455
return std::nullopt;
5556
}
5657
auto&& floatAttr = dyn_cast<FloatAttr>(constantOp.getValue());
57-
if (!floatAttr) {
58-
return std::nullopt;
58+
if (floatAttr) {
59+
return floatAttr.getValueAsDouble();
60+
}
61+
auto&& intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
62+
if (intAttr) {
63+
if (intAttr.getType().isUnsignedInteger()) {
64+
return static_cast<double>(intAttr.getValue().getZExtValue());
65+
}
66+
// interpret both signed+signless as signed integers
67+
return static_cast<double>(intAttr.getValue().getSExtValue());
5968
}
60-
return floatAttr.getValueAsDouble();
69+
return std::nullopt;
6170
}
6271

6372
} // namespace mlir::utils

mlir/unittests/Dialect/Utils/test_utils.cpp

Lines changed: 85 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,77 +10,113 @@
1010

1111
#include "mlir/Dialect/Utils/Utils.h"
1212

13+
#include <cstdint>
1314
#include <gtest/gtest.h>
15+
#include <limits>
16+
#include <llvm/ADT/APInt.h>
17+
#include <llvm/ADT/SmallVector.h>
18+
#include <memory>
1419
#include <mlir/Dialect/Arith/IR/Arith.h>
1520
#include <mlir/Dialect/Func/IR/FuncOps.h>
16-
#include <mlir/IR/BuiltinOps.h>
21+
#include <mlir/IR/Builders.h>
1722
#include <mlir/IR/MLIRContext.h>
18-
#include <mlir/Parser/Parser.h>
23+
#include <mlir/IR/Operation.h>
24+
#include <mlir/IR/Value.h>
1925

2026
using namespace mlir;
2127

2228
class UtilsTest : public ::testing::Test {
2329
protected:
2430
MLIRContext context;
31+
std::unique_ptr<OpBuilder> builder;
2532

2633
void SetUp() override {
2734
context.loadDialect<func::FuncDialect>();
2835
context.loadDialect<arith::ArithDialect>();
36+
37+
builder = std::make_unique<OpBuilder>(&context);
38+
}
39+
40+
arith::AddFOp createAddition(double a, double b) {
41+
auto firstOperand = builder->create<arith::ConstantOp>(
42+
builder->getUnknownLoc(), builder->getF64FloatAttr(a));
43+
auto secondOperand = builder->create<arith::ConstantOp>(
44+
builder->getUnknownLoc(), builder->getF64FloatAttr(b));
45+
return builder->create<arith::AddFOp>(builder->getUnknownLoc(),
46+
firstOperand, secondOperand);
2947
}
3048
};
3149

3250
TEST_F(UtilsTest, valueToDouble) {
33-
auto moduleOp = parseSourceString<ModuleOp>(
34-
"func.func @test() { arith.constant 1.234 : f64\n return }", &context);
35-
ASSERT_TRUE(moduleOp);
36-
37-
for (auto&& funcOp : moduleOp->getOps<func::FuncOp>()) {
38-
for (auto&& constantOp : funcOp.getOps<arith::ConstantOp>()) {
39-
auto value = constantOp.getResult();
40-
auto stdValue = utils::valueToDouble(value);
41-
ASSERT_TRUE(stdValue.has_value());
42-
EXPECT_DOUBLE_EQ(stdValue.value(), 1.234);
43-
return;
44-
}
45-
FAIL() << "No arith::ConstantOp found in function!";
46-
}
47-
FAIL() << "No func::FuncOp found in module!";
51+
constexpr double expectedValue = 1.234;
52+
auto op = builder->create<arith::ConstantOp>(
53+
builder->getUnknownLoc(), builder->getF64FloatAttr(expectedValue));
54+
ASSERT_TRUE(op);
55+
56+
auto value = op.getResult();
57+
auto stdValue = utils::valueToDouble(value);
58+
ASSERT_TRUE(stdValue.has_value());
59+
EXPECT_DOUBLE_EQ(stdValue.value(), expectedValue);
60+
}
61+
62+
TEST_F(UtilsTest, valueToDoubleCastFromIntegerType) {
63+
constexpr int expectedValue = 42;
64+
auto op = builder->create<arith::ConstantOp>(
65+
builder->getUnknownLoc(), builder->getI32IntegerAttr(expectedValue));
66+
ASSERT_TRUE(op);
67+
68+
auto value = op.getResult();
69+
auto stdValue = utils::valueToDouble(value);
70+
ASSERT_TRUE(stdValue.has_value());
71+
EXPECT_DOUBLE_EQ(stdValue.value(), expectedValue);
72+
}
73+
74+
TEST_F(UtilsTest, valueToDoubleCastFromMaxUnsignedInteger) {
75+
constexpr auto expectedValue = std::numeric_limits<uint64_t>::max();
76+
constexpr auto bitCount = 64;
77+
auto op = builder->create<arith::ConstantOp>(
78+
builder->getUnknownLoc(),
79+
builder->getIntegerAttr(builder->getIntegerType(bitCount, false),
80+
llvm::APInt::getMaxValue(bitCount)));
81+
ASSERT_TRUE(op);
82+
83+
auto value = op.getResult();
84+
auto stdValue = utils::valueToDouble(value);
85+
ASSERT_TRUE(stdValue.has_value());
86+
// cast to double will lose precision, but difference to maximum value of
87+
// int64_t is large enough that the check still makes sense
88+
EXPECT_DOUBLE_EQ(stdValue.value(), static_cast<double>(expectedValue));
4889
}
4990

5091
TEST_F(UtilsTest, valueToDoubleWrongType) {
51-
auto moduleOp = parseSourceString<ModuleOp>(
52-
"func.func @test() { arith.constant 42 : i32\n return }", &context);
53-
ASSERT_TRUE(moduleOp);
54-
55-
for (auto&& funcOp : moduleOp->getOps<func::FuncOp>()) {
56-
for (auto&& constantOp : funcOp.getOps<arith::ConstantOp>()) {
57-
auto value = constantOp.getResult();
58-
auto stdValue = utils::valueToDouble(value);
59-
EXPECT_FALSE(stdValue.has_value());
60-
return;
61-
}
62-
FAIL() << "No arith::ConstantOp found in function!";
63-
}
64-
FAIL() << "No func::FuncOp found in module!";
92+
auto op = builder->create<arith::ConstantOp>(builder->getUnknownLoc(),
93+
builder->getStringAttr("test"));
94+
ASSERT_TRUE(op);
95+
96+
auto value = op.getResult();
97+
auto stdValue = utils::valueToDouble(value);
98+
EXPECT_FALSE(stdValue.has_value());
6599
}
66100

67101
TEST_F(UtilsTest, valueToDoubleNonStaticValue) {
68-
auto moduleOp = parseSourceString<ModuleOp>("func.func @test() {\n"
69-
"%0 = arith.constant 1.1 : f64\n"
70-
"%1 = arith.constant 2.2 : f64\n"
71-
"arith.addf %0, %1 : f64\n"
72-
"return }",
73-
&context);
74-
ASSERT_TRUE(moduleOp);
75-
76-
for (auto&& funcOp : moduleOp->getOps<func::FuncOp>()) {
77-
for (auto&& addOp : funcOp.getOps<arith::AddFOp>()) {
78-
auto value = addOp.getResult();
79-
auto stdValue = utils::valueToDouble(value);
80-
EXPECT_FALSE(stdValue.has_value());
81-
return;
82-
}
83-
FAIL() << "No arith::AddFOp found in function!";
84-
}
85-
FAIL() << "No func::FuncOp found in module!";
102+
auto op = createAddition(9.5, 21.5);
103+
ASSERT_TRUE(op);
104+
105+
auto value = op.getResult();
106+
auto stdValue = utils::valueToDouble(value);
107+
EXPECT_FALSE(stdValue.has_value());
108+
}
109+
110+
TEST_F(UtilsTest, valueToDoubleNonStaticValueAfterFolding) {
111+
auto op = createAddition(1.1, 2.2);
112+
ASSERT_TRUE(op);
113+
114+
llvm::SmallVector<Value> tmp;
115+
llvm::SmallVector<Operation*> newConstants;
116+
ASSERT_TRUE(builder->tryFold(op, tmp, &newConstants).succeeded());
117+
ASSERT_EQ(newConstants.size(), 1);
118+
auto value = newConstants[0]->getResult(0);
119+
auto stdValue = utils::valueToDouble(value);
120+
ASSERT_TRUE(stdValue.has_value());
121+
EXPECT_DOUBLE_EQ(stdValue.value(), 3.3);
86122
}

0 commit comments

Comments
 (0)