Skip to content

Commit ce6db11

Browse files
committed
[mlir][emitc] Support convert arith.extf and arith.truncf to emitc
1 parent ef77188 commit ce6db11

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,37 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
733733
}
734734
};
735735

736+
// Floating-point to floating-point conversions.
737+
template <typename CastOp>
738+
class FpCastOpConversion : public OpConversionPattern<CastOp> {
739+
public:
740+
FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
741+
: OpConversionPattern<CastOp>(typeConverter, context) {}
742+
743+
LogicalResult
744+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
745+
ConversionPatternRewriter &rewriter) const override {
746+
// Vectors in particular are not supported
747+
Type operandType = adaptor.getIn().getType();
748+
if (!emitc::isSupportedFloatType(operandType))
749+
return rewriter.notifyMatchFailure(castOp,
750+
"unsupported cast source type");
751+
752+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
753+
if (!dstType)
754+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
755+
756+
if (!emitc::isSupportedFloatType(dstType))
757+
return rewriter.notifyMatchFailure(castOp,
758+
"unsupported cast destination type");
759+
760+
Value fpCastOperand = adaptor.getIn();
761+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
762+
763+
return success();
764+
}
765+
};
766+
736767
} // namespace
737768

738769
//===----------------------------------------------------------------------===//
@@ -778,7 +809,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
778809
ItoFCastOpConversion<arith::SIToFPOp>,
779810
ItoFCastOpConversion<arith::UIToFPOp>,
780811
FtoICastOpConversion<arith::FPToSIOp>,
781-
FtoICastOpConversion<arith::FPToUIOp>
812+
FtoICastOpConversion<arith::FPToUIOp>,
813+
FpCastOpConversion<arith::ExtFOp>,
814+
FpCastOpConversion<arith::TruncFOp>
782815
>(typeConverter, ctx);
783816
// clang-format on
784817
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
739739

740740
return %div : i32
741741
}
742+
743+
// -----
744+
745+
func.func @arith_extf(%arg0: f16) -> f64 {
746+
// CHECK-LABEL: arith_extf
747+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
748+
// CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
749+
%extd0 = arith.extf %arg0 : f16 to f32
750+
// CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
751+
%extd1 = arith.extf %extd0 : f32 to f64
752+
753+
return %extd1 : f64
754+
}
755+
756+
// -----
757+
758+
func.func @arith_truncf(%arg0: f64) -> f16 {
759+
// CHECK-LABEL: arith_truncf
760+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
761+
// CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
762+
%truncd0 = arith.truncf %arg0 : f64 to f32
763+
// CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
764+
%truncd1 = arith.truncf %truncd0 : f32 to f16
765+
766+
return %truncd1 : f16
767+
}

0 commit comments

Comments
 (0)