Skip to content

Commit fc72dde

Browse files
committed
Removed roundmode and satmode from f32x4tofpx4op base
1 parent 8effd8b commit fc72dde

File tree

4 files changed

+7
-158
lines changed

4 files changed

+7
-158
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,21 +1969,21 @@ def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "conve
19691969

19701970
// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
19711971
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
1972+
// These operations always use RS (stochastic rounding) mode with SATFINITE saturation.
19721973
class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic, Type resultType> :
19731974
NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
19741975
Results<(outs resultType:$dst)>,
19751976
Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits,
1976-
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
1977-
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
19781977
DefaultValuedAttr<BoolAttr, "false">:$relu,
19791978
TypeAttr:$dstTy)> {
1980-
let summary = "Convert vector<4xf32> to packed " # dstFormat # " with stochastic rounding (.rs)";
1979+
let summary = "Convert vector<4xf32> to packed " # dstFormat # " with stochastic rounding (.rs) and satfinite";
19811980
let description = [{
19821981
Converts a vector<4xf32> to packed }] # dstFormat # [{ format using
1983-
stochastic rounding (.rs) mode with randomness provided by the `rbits`
1984-
parameter. The `dstTy` attribute specifies the target format. The `relu`
1985-
attribute clamps negative results to 0. The `sat` attribute determines
1986-
saturation behavior.
1982+
stochastic rounding (.rs) mode with SATFINITE saturation. Randomness is
1983+
provided by the `rbits` parameter. The `dstTy` attribute specifies the
1984+
target format. The `relu` attribute clamps negative results to 0.
1985+
1986+
Note: These operations always use RS rounding mode and SATFINITE saturation mode.
19871987

19881988
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
19891989
}];

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -386,14 +386,6 @@ LogicalResult ConvertF32x2ToBF16x2Op::verify() {
386386
LogicalResult ConvertF32x4ToF8x4Op::verify() {
387387
mlir::MLIRContext *ctx = getContext();
388388

389-
if (getRnd() != FPRoundingMode::RS)
390-
return emitOpError("Only RS rounding mode is supported for "
391-
"conversions from f32x4 to f8x4.");
392-
393-
if (getSat() == SaturationMode::NONE)
394-
return emitOpError("Only SATFINITE saturation mode is supported for "
395-
"conversions from f32x4 to f8x4.");
396-
397389
if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
398390
return emitOpError("Only ")
399391
<< mlir::Float8E4M3FNType::get(ctx) << " and "
@@ -406,14 +398,6 @@ LogicalResult ConvertF32x4ToF8x4Op::verify() {
406398
LogicalResult ConvertF32x4ToF6x4Op::verify() {
407399
mlir::MLIRContext *ctx = getContext();
408400

409-
if (getRnd() != FPRoundingMode::RS)
410-
return emitOpError("Only RS rounding mode is supported for "
411-
"conversions from f32x4 to f6x4.");
412-
413-
if (getSat() == SaturationMode::NONE)
414-
return emitOpError("Only SATFINITE saturation mode is supported for "
415-
"conversions from f32x4 to f6x4.");
416-
417401
if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
418402
return emitOpError("Only ")
419403
<< mlir::Float6E2M3FNType::get(ctx) << " and "
@@ -426,14 +410,6 @@ LogicalResult ConvertF32x4ToF6x4Op::verify() {
426410
LogicalResult ConvertF32x4ToF4x4Op::verify() {
427411
mlir::MLIRContext *ctx = getContext();
428412

429-
if (getRnd() != FPRoundingMode::RS)
430-
return emitOpError("Only RS rounding mode is supported for "
431-
"conversions from f32x4 to f4x4.");
432-
433-
if (getSat() == SaturationMode::NONE)
434-
return emitOpError("Only SATFINITE saturation mode is supported for "
435-
"conversions from f32x4 to f4x4.");
436-
437413
if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
438414
return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
439415
<< " type is supported for conversions from "

mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,47 +14,6 @@ gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
1414

1515
// -----
1616

17-
// Test that FP8/FP6/FP4 conversions require satfinite mode
18-
llvm.func @invalid_sat_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
19-
// expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
20-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
21-
llvm.return %res : vector<4xi8>
22-
}
23-
24-
// -----
25-
26-
llvm.func @invalid_sat_mode_f8x4_e5m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
27-
// expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f8x4.}}
28-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f8E5M2)
29-
llvm.return %res : vector<4xi8>
30-
}
31-
32-
// -----
33-
34-
llvm.func @invalid_sat_mode_f6x4_e2m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
35-
// expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
36-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
37-
llvm.return %res : vector<4xi8>
38-
}
39-
40-
// -----
41-
42-
llvm.func @invalid_sat_mode_f6x4_e3m2(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
43-
// expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f6x4.}}
44-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
45-
llvm.return %res : vector<4xi8>
46-
}
47-
48-
// -----
49-
50-
llvm.func @invalid_sat_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
51-
// expected-error@+1 {{Only SATFINITE saturation mode is supported for conversions from f32x4 to f4x4.}}
52-
%res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {sat = #nvvm.sat_mode<none>} : vector<4xf32> -> i16 (f4E2M1FN)
53-
llvm.return %res : i16
54-
}
55-
56-
// -----
57-
5817
// Test that operations require stochastic rounding mode
5918
llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
6019
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
@@ -72,22 +31,6 @@ llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> ve
7231

7332
// -----
7433

75-
llvm.func @invalid_rnd_mode_f8x4_e4m3(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
76-
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f8x4.}}
77-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
78-
llvm.return %res : vector<4xi8>
79-
}
80-
81-
// -----
82-
83-
llvm.func @invalid_rnd_mode_f4x4_e2m1(%src : vector<4xf32>, %rbits : i32) -> i16 {
84-
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x4 to f4x4.}}
85-
%res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> i16 (f4E2M1FN)
86-
llvm.return %res : i16
87-
}
88-
89-
// -----
90-
9134
// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
9235
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
9336
// expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}

mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -101,27 +101,13 @@ llvm.func @convert_f32x4_to_f8x4_e4m3_rs(%src : vector<4xf32>, %rbits : i32) ->
101101
llvm.return %res : vector<4xi8>
102102
}
103103

104-
// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_satfinite
105-
llvm.func @convert_f32x4_to_f8x4_e4m3_rs_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
106-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
107-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
108-
llvm.return %res : vector<4xi8>
109-
}
110-
111104
// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_relu
112105
llvm.func @convert_f32x4_to_f8x4_e4m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
113106
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
114107
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
115108
llvm.return %res : vector<4xi8>
116109
}
117110

118-
// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_relu_satfinite
119-
llvm.func @convert_f32x4_to_f8x4_e4m3_rs_relu_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
120-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
121-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
122-
llvm.return %res : vector<4xi8>
123-
}
124-
125111
// -----
126112

127113
// Test F32x4 -> F8x4 (E5M2) with stochastic rounding (.rs)
@@ -133,27 +119,13 @@ llvm.func @convert_f32x4_to_f8x4_e5m2_rs(%src : vector<4xf32>, %rbits : i32) ->
133119
llvm.return %res : vector<4xi8>
134120
}
135121

136-
// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_satfinite
137-
llvm.func @convert_f32x4_to_f8x4_e5m2_rs_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
138-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
139-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E5M2)
140-
llvm.return %res : vector<4xi8>
141-
}
142-
143122
// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_relu
144123
llvm.func @convert_f32x4_to_f8x4_e5m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
145124
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
146125
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E5M2)
147126
llvm.return %res : vector<4xi8>
148127
}
149128

150-
// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_relu_satfinite
151-
llvm.func @convert_f32x4_to_f8x4_e5m2_rs_relu_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
152-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
153-
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f8E5M2)
154-
llvm.return %res : vector<4xi8>
155-
}
156-
157129
// -----
158130

159131
// Test F32x4 -> F6x4 (E2M3) with stochastic rounding (.rs)
@@ -165,27 +137,13 @@ llvm.func @convert_f32x4_to_f6x4_e2m3_rs(%src : vector<4xf32>, %rbits : i32) ->
165137
llvm.return %res : vector<4xi8>
166138
}
167139

168-
// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_satfinite
169-
llvm.func @convert_f32x4_to_f6x4_e2m3_rs_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
170-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
171-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
172-
llvm.return %res : vector<4xi8>
173-
}
174-
175140
// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_relu
176141
llvm.func @convert_f32x4_to_f6x4_e2m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
177142
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
178143
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
179144
llvm.return %res : vector<4xi8>
180145
}
181146

182-
// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_relu_satfinite
183-
llvm.func @convert_f32x4_to_f6x4_e2m3_rs_relu_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
184-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
185-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
186-
llvm.return %res : vector<4xi8>
187-
}
188-
189147
// -----
190148

191149
// Test F32x4 -> F6x4 (E3M2) with stochastic rounding (.rs)
@@ -197,27 +155,13 @@ llvm.func @convert_f32x4_to_f6x4_e3m2_rs(%src : vector<4xf32>, %rbits : i32) ->
197155
llvm.return %res : vector<4xi8>
198156
}
199157

200-
// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_satfinite
201-
llvm.func @convert_f32x4_to_f6x4_e3m2_rs_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
202-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
203-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
204-
llvm.return %res : vector<4xi8>
205-
}
206-
207158
// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_relu
208159
llvm.func @convert_f32x4_to_f6x4_e3m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
209160
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
210161
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
211162
llvm.return %res : vector<4xi8>
212163
}
213164

214-
// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_relu_satfinite
215-
llvm.func @convert_f32x4_to_f6x4_e3m2_rs_relu_satfinite(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
216-
// CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
217-
%res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
218-
llvm.return %res : vector<4xi8>
219-
}
220-
221165
// -----
222166

223167
// Test F32x4 -> F4x4 (E2M1) with stochastic rounding (.rs)
@@ -229,24 +173,10 @@ llvm.func @convert_f32x4_to_f4x4_e2m1_rs(%src : vector<4xf32>, %rbits : i32) ->
229173
llvm.return %res : i16
230174
}
231175

232-
// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_satfinite
233-
llvm.func @convert_f32x4_to_f4x4_e2m1_rs_satfinite(%src : vector<4xf32>, %rbits : i32) -> i16 {
234-
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
235-
%res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> i16 (f4E2M1FN)
236-
llvm.return %res : i16
237-
}
238-
239176
// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_relu
240177
llvm.func @convert_f32x4_to_f4x4_e2m1_rs_relu(%src : vector<4xf32>, %rbits : i32) -> i16 {
241178
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
242179
%res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {relu = true} : vector<4xf32> -> i16 (f4E2M1FN)
243180
llvm.return %res : i16
244181
}
245182

246-
// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_relu_satfinite
247-
llvm.func @convert_f32x4_to_f4x4_e2m1_rs_relu_satfinite(%src : vector<4xf32>, %rbits : i32) -> i16 {
248-
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
249-
%res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<4xf32> -> i16 (f4E2M1FN)
250-
llvm.return %res : i16
251-
}
252-

0 commit comments

Comments
 (0)