Skip to content

Commit f9a8006

Browse files
authored
[mlir][emitc] Support convert arith.extf and arith.truncf to emitc (#121184)
1 parent 0195ec4 commit f9a8006

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,43 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
733733
}
734734
};
735735

736+
// Floating-point to floating-point conversions.
737+
template <typename CastOp>
738+
class FpCastOpConversion : public OpConversionPattern<CastOp> {
739+
public:
740+
FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
741+
: OpConversionPattern<CastOp>(typeConverter, context) {}
742+
743+
LogicalResult
744+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
745+
ConversionPatternRewriter &rewriter) const override {
746+
// Vectors in particular are not supported.
747+
Type operandType = adaptor.getIn().getType();
748+
if (!emitc::isSupportedFloatType(operandType))
749+
return rewriter.notifyMatchFailure(castOp,
750+
"unsupported cast source type");
751+
if (auto roundingModeOp =
752+
dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
753+
// Only supporting default rounding mode as of now.
754+
if (roundingModeOp.getRoundingModeAttr())
755+
return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
756+
}
757+
758+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
759+
if (!dstType)
760+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
761+
762+
if (!emitc::isSupportedFloatType(dstType))
763+
return rewriter.notifyMatchFailure(castOp,
764+
"unsupported cast destination type");
765+
766+
Value fpCastOperand = adaptor.getIn();
767+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
768+
769+
return success();
770+
}
771+
};
772+
736773
} // namespace
737774

738775
//===----------------------------------------------------------------------===//
@@ -778,7 +815,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
778815
ItoFCastOpConversion<arith::SIToFPOp>,
779816
ItoFCastOpConversion<arith::UIToFPOp>,
780817
FtoICastOpConversion<arith::FPToSIOp>,
781-
FtoICastOpConversion<arith::FPToUIOp>
818+
FtoICastOpConversion<arith::FPToUIOp>,
819+
FpCastOpConversion<arith::ExtFOp>,
820+
FpCastOpConversion<arith::TruncFOp>
782821
>(typeConverter, ctx);
783822
// clang-format on
784823
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,43 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
149149
%divui = arith.remui %arg0, %arg1 : vector<5xi32>
150150
return %divui: vector<5xi32>
151151
}
152+
153+
// -----
154+
155+
func.func @arith_truncf(%arg0: f64) -> f32 {
156+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
157+
%truncd = arith.truncf %arg0 to_nearest_away : f64 to f32
158+
return %truncd : f32
159+
}
160+
161+
// -----
162+
163+
func.func @arith_extf_f128(%arg0: f32) -> f128 {
164+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
165+
%extd = arith.extf %arg0 : f32 to f128
166+
return %extd : f128
167+
}
168+
169+
// -----
170+
171+
func.func @arith_truncf_f128(%arg0: f128) -> f32 {
172+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
173+
%truncd = arith.truncf %arg0 : f128 to f32
174+
return %truncd : f32
175+
}
176+
177+
// -----
178+
179+
func.func @arith_extf_vector(%arg0: vector<4xf32>) -> vector<4xf64> {
180+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
181+
%extd = arith.extf %arg0 : vector<4xf32> to vector<4xf64>
182+
return %extd : vector<4xf64>
183+
}
184+
185+
// -----
186+
187+
func.func @arith_truncf_vector(%arg0: vector<4xf64>) -> vector<4xf32> {
188+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
189+
%truncd = arith.truncf %arg0 : vector<4xf64> to vector<4xf32>
190+
return %truncd : vector<4xf32>
191+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
739739

740740
return %div : i32
741741
}
742+
743+
// -----
744+
745+
func.func @arith_extf(%arg0: f16) -> f64 {
746+
// CHECK-LABEL: arith_extf
747+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
748+
// CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
749+
%extd0 = arith.extf %arg0 : f16 to f32
750+
// CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
751+
%extd1 = arith.extf %extd0 : f32 to f64
752+
753+
return %extd1 : f64
754+
}
755+
756+
// -----
757+
758+
func.func @arith_truncf(%arg0: f64) -> f16 {
759+
// CHECK-LABEL: arith_truncf
760+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
761+
// CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
762+
%truncd0 = arith.truncf %arg0 : f64 to f32
763+
// CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
764+
%truncd1 = arith.truncf %truncd0 : f32 to f16
765+
766+
return %truncd1 : f16
767+
}

0 commit comments

Comments
 (0)