Skip to content

Commit 1796a67

Browse files
mplatingsRoboTux
andcommitted
[mlir][spirv] Handle signedness casts
Lower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast. arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead. Co-authored-by: Thomas Preud'homme <[email protected]>
1 parent d9cd6ed commit 1796a67

File tree

2 files changed

+125
-31
lines changed

2 files changed

+125
-31
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
856856
llvm_unreachable("Unhandled rounding mode");
857857
}
858858

859+
static bool isSignednessCast(Type srcType, Type dstType) {
860+
if (srcType.isInteger() && dstType.isInteger()) {
861+
return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
862+
}
863+
if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) {
864+
return isSignednessCast(cast<VectorType>(srcType).getElementType(),
865+
cast<VectorType>(dstType).getElementType());
866+
}
867+
return false;
868+
}
869+
859870
/// Converts type-casting standard operations to SPIR-V operations.
860871
template <typename Op, typename SPIRVOp>
861872
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
864875
LogicalResult
865876
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
866877
ConversionPatternRewriter &rewriter) const override {
867-
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
868-
Type dstType = this->getTypeConverter()->convertType(op.getType());
869-
if (!dstType)
870-
return getTypeConversionFailure(rewriter, op);
878+
TypeRange dstTypes;
879+
SmallVector<Type> newDstTypes;
880+
SmallVector<Value> unrealizedConvCastSrcs;
881+
SmallVector<Type> unrealizedConvCastDstTypes;
882+
constexpr bool isUnrealizedConvCast =
883+
std::is_same_v<Op, UnrealizedConversionCastOp>;
884+
if constexpr (isUnrealizedConvCast)
885+
dstTypes = op.getOutputs().getTypes();
886+
else
887+
dstTypes = op.getType();
888+
LogicalResult matched = failure();
889+
for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) {
890+
Type srcType = src.getType();
891+
// Use UnrealizedConversionCast as the bridge so that we don't need to
892+
// pull in patterns for other dialects.
893+
if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
894+
newDstTypes.push_back(dstType);
895+
unrealizedConvCastSrcs.push_back(src);
896+
unrealizedConvCastDstTypes.push_back(dstType);
897+
continue;
898+
}
899+
dstType = this->getTypeConverter()->convertType(dstType);
900+
if (!dstType)
901+
return getTypeConversionFailure(rewriter, op);
902+
903+
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
904+
return failure();
905+
matched = success();
906+
newDstTypes.push_back(dstType);
907+
}
871908

872-
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
909+
if (failed(matched))
873910
return failure();
874911

875-
if (dstType == srcType) {
876-
// Due to type conversion, we are seeing the same source and target type.
877-
// Then we can just erase this operation by forwarding its operand.
878-
rewriter.replaceOp(op, adaptor.getOperands().front());
879-
} else {
880-
// Compute new rounding mode (if any).
881-
std::optional<spirv::FPRoundingMode> rm = std::nullopt;
882-
if (auto roundingModeOp =
883-
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
884-
if (arith::RoundingModeAttr roundingMode =
885-
roundingModeOp.getRoundingModeAttr()) {
886-
if (!(rm =
887-
convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
888-
return rewriter.notifyMatchFailure(
889-
op->getLoc(),
890-
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
891-
}
912+
// Compute new rounding mode (if any).
913+
Location loc = op->getLoc();
914+
std::optional<spirv::FPRoundingMode> rm = std::nullopt;
915+
if (auto roundingModeOp =
916+
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
917+
if (arith::RoundingModeAttr roundingMode =
918+
roundingModeOp.getRoundingModeAttr()) {
919+
if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
920+
return rewriter.notifyMatchFailure(
921+
loc,
922+
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
892923
}
893924
}
894-
// Create replacement op and attach rounding mode attribute (if any).
895-
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
896-
op, dstType, adaptor.getOperands());
897-
if (rm) {
898-
newOp->setAttr(
899-
getDecorationString(spirv::Decoration::FPRoundingMode),
900-
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
925+
}
926+
927+
llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
928+
if (!unrealizedConvCastSrcs.empty()) {
929+
auto newOp = rewriter.create<UnrealizedConversionCastOp>(
930+
loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs);
931+
for (auto [src, dst] :
932+
llvm::zip(unrealizedConvCastSrcs, newOp.getResults()))
933+
unrealizedConvCastSrcDstMap[src] = dst;
934+
}
935+
936+
SmallVector<Value> newValues;
937+
for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) {
938+
Type srcType = src.getType();
939+
if (dstType == srcType) {
940+
// Due to type conversion, we are seeing the same source and target
941+
// type. Then we can just erase this operation by forwarding its
942+
// operand.
943+
newValues.push_back(src);
944+
} else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
945+
newValues.push_back(unrealizedConvCastSrcDstMap[src]);
946+
} else {
947+
// Create replacement op and attach rounding mode attribute (if any).
948+
auto newOp = rewriter.template create<SPIRVOp>(loc, dstType, src);
949+
if (rm) {
950+
newOp->setAttr(
951+
getDecorationString(spirv::Decoration::FPRoundingMode),
952+
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
953+
}
954+
newValues.push_back(newOp.getResult());
901955
}
902956
}
957+
rewriter.replaceOp(op, newValues);
903958
return success();
904959
}
905960
};
@@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
13311386
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
13321387
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
13331388
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1389+
TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>,
13341390
CmpIOpBooleanPattern, CmpIOpPattern,
13351391
CmpFOpNanNonePattern, CmpFOpPattern,
13361392
AddUIExtendedOpPattern,
@@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass
13851441
SPIRVTypeConverter typeConverter(targetAttr, options);
13861442

13871443
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
1388-
// in patterns for other dialects.
1389-
target->addLegalOp<UnrealizedConversionCastOp>();
1444+
// in patterns for other dialects. If the UnrealizedConversionCast is
1445+
// between integers of the same bitwidth, it is either a nop or a
1446+
// signedness cast which the corresponding pattern convert to Bitcast.
1447+
target->addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1448+
[&](UnrealizedConversionCastOp op) {
1449+
for (auto [srcType, dstType] :
1450+
llvm::zip(op.getOperandTypes(), op.getResultTypes()))
1451+
if (isSignednessCast(srcType, dstType))
1452+
return false;
1453+
return true;
1454+
});
13901455

13911456
// Fail hard when there are any remaining 'arith' ops.
13921457
target->addIllegalDialect<arith::ArithDialect>();

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,35 @@ func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
743743
return
744744
}
745745

746+
// CHECK-LABEL: @unrealized_conversion_cast
747+
func.func @unrealized_conversion_cast(%arg0: vector<3xi64>, %arg1: i16, %arg2: f32) {
748+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
749+
%0 = builtin.unrealized_conversion_cast %arg0 : vector<3xi64> to vector<3xui64>
750+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
751+
%1 = builtin.unrealized_conversion_cast %arg1 : i16 to ui16
752+
753+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
754+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
755+
%2:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, ui16
756+
757+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
758+
%3:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xi64>, ui16
759+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
760+
%4:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, i16
761+
762+
// bitcast from float to int should be represented using arith.bitcast
763+
// CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to i32
764+
%5 = builtin.unrealized_conversion_cast %arg2 : f32 to i32
765+
766+
// test mixed signedness and non-signedness cast
767+
// CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to f16
768+
// CHECK-NEXT: spirv.Bitcast %{{.+}} : i32 to ui32
769+
%6:2 = builtin.unrealized_conversion_cast %5, %arg2 : i32, f32 to ui32, f16
770+
771+
// CHECK-NEXT: return
772+
return
773+
}
774+
746775
// CHECK-LABEL: @fpext1
747776
func.func @fpext1(%arg0: f16) -> f64 {
748777
// CHECK: spirv.FConvert %{{.*}} : f16 to f64

0 commit comments

Comments
 (0)