Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 96 additions & 31 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
llvm_unreachable("Unhandled rounding mode");
}

static bool isSignednessCast(Type srcType, Type dstType) {
if (srcType.isInteger() && dstType.isInteger()) {
return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) {
return isSignednessCast(cast<VectorType>(srcType).getElementType(),
cast<VectorType>(dstType).getElementType());
}
return false;
}

/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
Expand All @@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
TypeRange dstTypes;
SmallVector<Type> newDstTypes;
SmallVector<Value> unrealizedConvCastSrcs;
SmallVector<Type> unrealizedConvCastDstTypes;
constexpr bool isUnrealizedConvCast =
std::is_same_v<Op, UnrealizedConversionCastOp>;
if constexpr (isUnrealizedConvCast)
dstTypes = op.getOutputs().getTypes();
else
dstTypes = op.getType();
LogicalResult matched = failure();
for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) {
Type srcType = src.getType();
// Use UnrealizedConversionCast as the bridge so that we don't need to
// pull in patterns for other dialects.
if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
newDstTypes.push_back(dstType);
unrealizedConvCastSrcs.push_back(src);
unrealizedConvCastDstTypes.push_back(dstType);
continue;
}
dstType = this->getTypeConverter()->convertType(dstType);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could it make sense to create a new variable newDstType to avoid confusion?

if (!dstType)
return getTypeConversionFailure(rewriter, op);

if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
matched = success();
newDstTypes.push_back(dstType);
}

if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
if (failed(matched))
Copy link
Contributor

@fabrizio-indirli fabrizio-indirli Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we get rid of the matched variable, and instead create a const bool variable after the loop to check if !newDstTypes.empty() && (newDstTypes.size() > unrealizedConvCastDstTypes.size()) here, or am I missing something?

return failure();

if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(op, adaptor.getOperands().front());
} else {
// Compute new rounding mode (if any).
std::optional<spirv::FPRoundingMode> rm = std::nullopt;
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
if (arith::RoundingModeAttr roundingMode =
roundingModeOp.getRoundingModeAttr()) {
if (!(rm =
convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
}
// Compute new rounding mode (if any).
Location loc = op->getLoc();
std::optional<spirv::FPRoundingMode> rm = std::nullopt;
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
if (arith::RoundingModeAttr roundingMode =
roundingModeOp.getRoundingModeAttr()) {
if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
return rewriter.notifyMatchFailure(
loc,
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
}
}
// Create replacement op and attach rounding mode attribute (if any).
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
op, dstType, adaptor.getOperands());
if (rm) {
newOp->setAttr(
getDecorationString(spirv::Decoration::FPRoundingMode),
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
}

llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
Copy link
Contributor

@fabrizio-indirli fabrizio-indirli Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could it make sense to add a comment explaining what's happening here? E.g. "Recreate unrealized_conversion_cast ops for unhandled casts" (if I understood the logic correctly)?

if (!unrealizedConvCastSrcs.empty()) {
auto newOp = rewriter.create<UnrealizedConversionCastOp>(
loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs);
for (auto [src, dst] :
llvm::zip(unrealizedConvCastSrcs, newOp.getResults()))
unrealizedConvCastSrcDstMap[src] = dst;
}

SmallVector<Value> newValues;
for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) {
Type srcType = src.getType();
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target
// type. Then we can just erase this operation by forwarding its
// operand.
newValues.push_back(src);
} else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
newValues.push_back(unrealizedConvCastSrcDstMap[src]);
} else {
// Create replacement op and attach rounding mode attribute (if any).
auto newOp = rewriter.template create<SPIRVOp>(loc, dstType, src);
if (rm) {
newOp->setAttr(
getDecorationString(spirv::Decoration::FPRoundingMode),
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
}
newValues.push_back(newOp.getResult());
}
}
rewriter.replaceOp(op, newValues);
return success();
}
};
Expand Down Expand Up @@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
AddUIExtendedOpPattern,
Expand Down Expand Up @@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass
SPIRVTypeConverter typeConverter(targetAttr, options);

// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
target->addLegalOp<UnrealizedConversionCastOp>();
// in patterns for other dialects. If the UnrealizedConversionCast is
// between integers of the same bitwidth, it is either a nop or a
// signedness cast which the corresponding pattern convert to Bitcast.
target->addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[&](UnrealizedConversionCastOp op) {
for (auto [srcType, dstType] :
llvm::zip(op.getOperandTypes(), op.getResultTypes()))
if (isSignednessCast(srcType, dstType))
return false;
return true;
});

// Fail hard when there are any remaining 'arith' ops.
target->addIllegalDialect<arith::ArithDialect>();
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,35 @@ func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
return
}

// CHECK-LABEL: @unrealized_conversion_cast
func.func @unrealized_conversion_cast(%arg0: vector<3xi64>, %arg1: i16, %arg2: f32) {
// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
%0 = builtin.unrealized_conversion_cast %arg0 : vector<3xi64> to vector<3xui64>
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
%1 = builtin.unrealized_conversion_cast %arg1 : i16 to ui16

// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
%2:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, ui16

// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
%3:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xi64>, ui16
// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
%4:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, i16

// bitcast from float to int should be represented using arith.bitcast
// CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to i32
%5 = builtin.unrealized_conversion_cast %arg2 : f32 to i32

// test mixed signedness and non-signedness cast
// CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to f16
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i32 to ui32
%6:2 = builtin.unrealized_conversion_cast %5, %arg2 : i32, f32 to ui32, f16

// CHECK-NEXT: return
return
}

// CHECK-LABEL: @fpext1
func.func @fpext1(%arg0: f16) -> f64 {
// CHECK: spirv.FConvert %{{.*}} : f16 to f64
Expand Down