Skip to content

Commit 5f29a9d

Browse files
authored
[FXML-4642] Lower extf, truncf to EmitC (#224)
1 parent 34e3674 commit 5f29a9d

File tree

5 files changed

+260
-12
lines changed

5 files changed

+260
-12
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,70 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
733733
}
734734
};
735735

736+
class TruncFConversion : public OpConversionPattern<arith::TruncFOp> {
737+
public:
738+
using OpConversionPattern<arith::TruncFOp>::OpConversionPattern;
739+
740+
LogicalResult
741+
matchAndRewrite(arith::TruncFOp castOp,
742+
typename arith::TruncFOp::Adaptor adaptor,
743+
ConversionPatternRewriter &rewriter) const override {
744+
// FIXME Upstream LLVM (commit 77cbc9bf60) brings in a rounding mode
745+
// attribute that we need to check. For now, the behavior is the default,
746+
// i.e. truncate.
747+
Type operandType = adaptor.getIn().getType();
748+
if (!emitc::isSupportedFloatType(operandType))
749+
return rewriter.notifyMatchFailure(castOp,
750+
"unsupported cast source type");
751+
752+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
753+
if (!dstType)
754+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
755+
756+
if (!emitc::isSupportedFloatType(dstType))
757+
return rewriter.notifyMatchFailure(castOp,
758+
"unsupported cast destination type");
759+
760+
if (!castOp.areCastCompatible(operandType, dstType))
761+
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");
762+
763+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
764+
adaptor.getIn());
765+
766+
return success();
767+
}
768+
};
769+
770+
class ExtFConversion : public OpConversionPattern<arith::ExtFOp> {
771+
public:
772+
using OpConversionPattern<arith::ExtFOp>::OpConversionPattern;
773+
774+
LogicalResult
775+
matchAndRewrite(arith::ExtFOp castOp, typename arith::ExtFOp::Adaptor adaptor,
776+
ConversionPatternRewriter &rewriter) const override {
777+
Type operandType = adaptor.getIn().getType();
778+
if (!emitc::isSupportedFloatType(operandType))
779+
return rewriter.notifyMatchFailure(castOp,
780+
"unsupported cast source type");
781+
782+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
783+
if (!dstType)
784+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
785+
786+
if (!emitc::isSupportedFloatType(dstType))
787+
return rewriter.notifyMatchFailure(castOp,
788+
"unsupported cast destination type");
789+
790+
if (!castOp.areCastCompatible(operandType, dstType))
791+
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");
792+
793+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
794+
adaptor.getIn());
795+
796+
return success();
797+
}
798+
};
799+
736800
} // namespace
737801

738802
//===----------------------------------------------------------------------===//
@@ -778,7 +842,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
778842
ItoFCastOpConversion<arith::SIToFPOp>,
779843
ItoFCastOpConversion<arith::UIToFPOp>,
780844
FtoICastOpConversion<arith::FPToSIOp>,
781-
FtoICastOpConversion<arith::FPToUIOp>
845+
FtoICastOpConversion<arith::FPToUIOp>,
846+
TruncFConversion,
847+
ExtFConversion
782848
>(typeConverter, ctx);
783849
// clang-format on
784850
}

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,7 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
114114
}
115115

116116
bool mlir::emitc::isSupportedFloatType(Type type) {
117-
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
118-
switch (floatType.getWidth()) {
119-
case 32:
120-
case 64:
121-
return true;
122-
default:
123-
return false;
124-
}
125-
}
126-
return false;
117+
return isa<Float32Type, Float64Type>(type);
127118
}
128119

129120
bool mlir::emitc::isPointerWideType(Type type) {

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,116 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
134134
%divui = arith.remui %arg0, %arg1 : vector<5xi32>
135135
return %divui: vector<5xi32>
136136
}
137+
138+
// -----
139+
140+
func.func @arith_extf_to_bf16(%arg0: f8E4M3FN) {
141+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
142+
%ext = arith.extf %arg0 : f8E4M3FN to bf16
143+
return
144+
}
145+
146+
// -----
147+
148+
func.func @arith_extf_to_f16(%arg0: f8E4M3FN) {
149+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
150+
%ext = arith.extf %arg0 : f8E4M3FN to f16
151+
return
152+
}
153+
154+
155+
// -----
156+
157+
func.func @arith_extf_to_tf32(%arg0: f8E4M3FN) {
158+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
159+
%ext = arith.extf %arg0 : f8E4M3FN to tf32
160+
return
161+
}
162+
163+
// -----
164+
165+
func.func @arith_extf_to_float80(%arg0: f8E4M3FN) {
166+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
167+
%ext = arith.extf %arg0 : f8E4M3FN to f80
168+
return
169+
}
170+
171+
// -----
172+
173+
func.func @arith_extf_to_float128(%arg0: f8E4M3FN) {
174+
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
175+
%ext = arith.extf %arg0 : f8E4M3FN to f128
176+
return
177+
}
178+
179+
// -----
180+
181+
func.func @arith_truncf_to_f80(%arg0: f128) {
182+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
183+
%trunc = arith.truncf %arg0 : f128 to f80
184+
return
185+
}
186+
187+
// -----
188+
189+
func.func @arith_truncf_to_tf32(%arg0: f64) {
190+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
191+
%trunc = arith.truncf %arg0 : f64 to tf32
192+
return
193+
}
194+
195+
// -----
196+
197+
func.func @arith_truncf_to_f16(%arg0: f64) {
198+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
199+
%trunc = arith.truncf %arg0 : f64 to f16
200+
return
201+
}
202+
203+
// -----
204+
205+
func.func @arith_truncf_to_bf16(%arg0: f64) {
206+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
207+
%trunc = arith.truncf %arg0 : f64 to bf16
208+
return
209+
}
210+
211+
// -----
212+
213+
func.func @arith_truncf_to_f8E4M3FN(%arg0: f64) {
214+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
215+
%trunc = arith.truncf %arg0 : f64 to f8E4M3FN
216+
return
217+
}
218+
219+
// -----
220+
221+
func.func @arith_truncf_to_f8E5M2(%arg0: f64) {
222+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
223+
%trunc = arith.truncf %arg0 : f64 to f8E5M2
224+
return
225+
}
226+
227+
// -----
228+
229+
func.func @arith_truncf_to_f8E4M3FNUZ(%arg0: f64) {
230+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
231+
%trunc = arith.truncf %arg0 : f64 to f8E4M3FNUZ
232+
return
233+
}
234+
235+
// -----
236+
237+
func.func @arith_truncf_to_f8E4M3FN(%arg0: f64) {
238+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
239+
%trunc = arith.truncf %arg0 : f64 to f8E4M3FN
240+
return
241+
}
242+
243+
// -----
244+
245+
func.func @arith_truncf_to_f8E4M3B11FNUZ(%arg0: f64) {
246+
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
247+
%trunc = arith.truncf %arg0 : f64 to f8E4M3B11FNUZ
248+
return
249+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,3 +717,17 @@ func.func @arith_index_castui(%arg0: i32) -> i32 {
717717

718718
return %int : i32
719719
}
720+
// -----
721+
722+
func.func @arith_extf_truncf(%arg0: f32, %arg1: f64) {
723+
// CHECK-LABEL: arith_extf_truncf
724+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f32, %[[Arg1:[^ ]*]]: f64)
725+
726+
// CHECK: emitc.cast %[[Arg0]] : f32 to f64
727+
%ext = arith.extf %arg0 : f32 to f64
728+
729+
// CHECK: emitc.cast %[[Arg1]] : f64 to f32
730+
%trunc = arith.truncf %arg1 : f64 to f32
731+
732+
return
733+
}

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,78 @@ func.func @illegal_integer_type(%arg0: i11, %arg1: i11) -> i11 {
9292

9393
// -----
9494

95-
func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
95+
func.func @illegal_f8E4M3B11FNUZ_type(%arg0: f8E4M3B11FNUZ, %arg1: f8E4M3B11FNUZ) {
96+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f8E4M3B11FNUZ'}}
97+
%mul = "emitc.mul" (%arg0, %arg1) : (f8E4M3B11FNUZ, f8E4M3B11FNUZ) -> f8E4M3B11FNUZ
98+
return
99+
}
100+
101+
// -----
102+
103+
func.func @illegal_f8E4M3FN_type(%arg0: f8E4M3FN, %arg1: f8E4M3FN) {
104+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f8E4M3FN'}}
105+
%mul = "emitc.mul" (%arg0, %arg1) : (f8E4M3FN, f8E4M3FN) -> f8E4M3FN
106+
return
107+
}
108+
109+
// -----
110+
111+
func.func @illegal_f8E4M3FNUZ_type(%arg0: f8E4M3FNUZ, %arg1: f8E4M3FNUZ) {
112+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f8E4M3FNUZ'}}
113+
%mul = "emitc.mul" (%arg0, %arg1) : (f8E4M3FNUZ, f8E4M3FNUZ) -> f8E4M3FNUZ
114+
return
115+
}
116+
117+
// -----
118+
119+
func.func @illegal_f8E5M2_type(%arg0: f8E5M2, %arg1: f8E5M2) {
120+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f8E5M2'}}
121+
%mul = "emitc.mul" (%arg0, %arg1) : (f8E5M2, f8E5M2) -> f8E5M2
122+
return
123+
}
124+
125+
// -----
126+
127+
func.func @illegal_f8E5M2FNUZ_type(%arg0: f8E5M2FNUZ, %arg1: f8E5M2FNUZ) {
128+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f8E5M2FNUZ'}}
129+
%mul = "emitc.mul" (%arg0, %arg1) : (f8E5M2FNUZ, f8E5M2FNUZ) -> f8E5M2FNUZ
130+
return
131+
}
132+
133+
// -----
134+
135+
func.func @illegal_f16_type(%arg0: f16, %arg1: f16) {
136+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f16'}}
137+
%mul = "emitc.mul" (%arg0, %arg1) : (f16, f16) -> f16
138+
return
139+
}
140+
141+
// -----
142+
143+
func.func @illegal_bf16_type(%arg0: bf16, %arg1: bf16) {
144+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'bf16'}}
145+
%mul = "emitc.mul" (%arg0, %arg1) : (bf16, bf16) -> bf16
146+
return
147+
}
148+
149+
// -----
150+
151+
func.func @illegal_f80_type(%arg0: f80, %arg1: f80) {
96152
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f80'}}
97153
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
98154
return
99155
}
100156

101157
// -----
102158

159+
func.func @illegal_f128_type(%arg0: f128, %arg1: f128) {
160+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f128'}}
161+
%mul = "emitc.mul" (%arg0, %arg1) : (f128, f128) -> f128
162+
return
163+
}
164+
165+
// -----
166+
103167
func.func @illegal_pointee_type() {
104168
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}}
105169
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11>

0 commit comments

Comments
 (0)