Skip to content

Commit 4bdd204

Browse files
committed
Add arith.constant support.
Handles scalar and vector.
1 parent 4b6d70d commit 4bdd204

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
9999
return builder.getF32FloatAttr(dstVal.convertToFloat());
100100
}
101101

102+
// Get IntegerAttr from FloatAttr.
103+
IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
104+
ConversionPatternRewriter &rewriter) {
105+
APFloat floatVal = floatAttr.getValue();
106+
APInt intVal = floatVal.bitcastToAPInt();
107+
return rewriter.getIntegerAttr(dstType, intVal);
108+
}
109+
102110
/// Returns true if the given `type` is a boolean scalar or vector type.
103111
static bool isBoolScalarOrVector(Type type) {
104112
assert(type && "Not a valid type");
@@ -296,8 +304,16 @@ struct ConstantCompositeOpPattern final
296304
SmallVector<Attribute, 8> elements;
297305
if (isa<FloatType>(srcElemType)) {
298306
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299-
FloatAttr dstAttr =
300-
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
307+
Attribute dstAttr = nullptr;
308+
// Handle 8-bit float conversion to 8-bit integer.
309+
if (srcElemType.getIntOrFloatBitWidth() == 8 &&
310+
isa<IntegerType>(dstElemType)) {
311+
dstAttr =
312+
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
313+
} else {
314+
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
315+
rewriter);
316+
}
301317
if (!dstAttr)
302318
return failure();
303319
elements.push_back(dstAttr);
@@ -361,11 +377,17 @@ struct ConstantScalarOpPattern final
361377
// Floating-point types.
362378
if (isa<FloatType>(srcType)) {
363379
auto srcAttr = cast<FloatAttr>(cstAttr);
364-
auto dstAttr = srcAttr;
380+
Attribute dstAttr = srcAttr;
365381

366382
// Floating-point types not supported in the target environment are all
367383
// converted to float type.
368-
if (srcType != dstType) {
384+
if (srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
385+
dstType.getIntOrFloatBitWidth() == 8) {
386+
// If the source is an 8-bit float, convert it to a 8-bit integer.
387+
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
388+
if (!dstAttr)
389+
return failure();
390+
} else if (srcType != dstType) {
369391
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370392
if (!dstAttr)
371393
return failure();

0 commit comments

Comments
 (0)