Skip to content

Commit 2a42a85

Browse files
authored
[MLIR][NVVM] Add support for Convert Ops with rs-rounding mode (#165736)
Added NVVM dialect operations for stochastic rounding (.rs) conversions from F32 to various packed floating-point formats. These operations map to existing PTX instructions and LLVM intrinsics.

Supported conversions:
- F32x2 to F16x2/BF16x2 (with optional relu and satfinite modifiers)
- F32x4 to packed F8 formats (E4M3, E5M2)
- F32x4 to packed F6 formats (E2M3, E3M2)
- F32x4 to packed F4 format (E2M1)

All operations use stochastic rounding, with randomness provided via an rbits parameter. The F32x2 conversions accept optional relu and saturation modifiers; the F32x4 conversions accept an optional relu modifier and always use satfinite saturation.
1 parent 475c632 commit 2a42a85

File tree

4 files changed

+496
-1
lines changed

4 files changed

+496
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,10 +1589,11 @@ def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
15891589
def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
15901590
def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
15911591
def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
1592+
def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">;
15921593

15931594
def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
15941595
[FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
1595-
FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
1596+
FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> {
15961597
let genSpecializedAttr = 0;
15971598
let cppNamespace = "::mlir::NVVM";
15981599
}
@@ -1906,6 +1907,96 @@ def NVVM_ConvertF6x2ToF16x2Op :
19061907
def NVVM_ConvertF4x2ToF16x2Op :
19071908
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
19081909

1910+
//===----------------------------------------------------------------------===//
1911+
// NVVM Stochastic Rounding Conversion Ops
1912+
//===----------------------------------------------------------------------===//
1913+
1914+
// Shared definition for the ops that pack two F32 scalars into an FPx2
// result (F16x2 or BF16x2).
//
// Template parameters:
//   dstFormat - destination format name spliced into the summary/description.
//   mnemonic  - op mnemonic forwarded to NVVM_Op.
//   dstType   - type constraint of the single `$dst` result.
//
// TODO: Add the .rn and .rz rounding variants in a separate PR; only the .rs
// rounding mode is supported for now.
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic,
                                    Type dstType>
    : NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
      Results<(outs dstType:$dst)>,
      Arguments<(ins
        F32:$src_hi,
        F32:$src_lo,
        I32:$rbits,
        DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
        DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
        DefaultValuedAttr<BoolAttr, "false">:$relu)> {
  let summary = "Convert two F32 values to packed " # dstFormat
                # " with stochastic rounding (.rs)";
  let description = [{
    Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
    rounding (.rs) mode with randomness provided by the `rbits` parameter. The
    `relu` attribute clamps negative results to 0. The `sat` attribute determines
    saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
    `a` and `b` in the PTX ISA, respectively.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // The verifier is hand-written per concrete op (declared via hasVerifier).
  let hasVerifier = 1;

  // Only the result type is spelled out; `rnd`, `sat` and `relu` travel in
  // the attribute dictionary.
  let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";

  // Each concrete op supplies its own getIntrinsicID() definition.
  let extraClassDeclaration = [{
    llvm::Intrinsic::ID getIntrinsicID();
  }];

  // Translation to LLVM IR: a single call to the intrinsic chosen by the op.
  string llvmBuilder = [{
    auto intId = op.getIntrinsicID();
    $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
  }];
}
1949+
1950+
// Stochastic-rounding conversion: two F32 values -> vector<2xf16>.
def NVVM_ConvertF32x2ToF16x2Op
    : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2",
                                    VectorOfLengthAndType<[2], [F16]>>;
1952+
1953+
// Stochastic-rounding conversion: two F32 values -> vector<2xbf16>.
def NVVM_ConvertF32x2ToBF16x2Op
    : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2",
                                    VectorOfLengthAndType<[2], [BF16]>>;
1955+
1956+
// Shared definition for the ops that convert a vector<4xf32> into a packed
// FPx4 result (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4).
//
// Unlike the F32x2 conversions, these ops expose no rounding or saturation
// attribute: they always use RS (stochastic rounding) with SATFINITE.
//
// Template parameters:
//   dstFormat - destination format name spliced into the summary/description.
//   mnemonic  - op mnemonic forwarded to NVVM_Op.
//   dstType   - type constraint of the single `$dst` result.
class NVVM_ConvertF32x4ToFPx4OpBase<string dstFormat, string mnemonic,
                                    Type dstType>
    : NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
      Results<(outs dstType:$dst)>,
      Arguments<(ins
        VectorOfLengthAndType<[4], [F32]>:$src,
        I32:$rbits,
        DefaultValuedAttr<BoolAttr, "false">:$relu,
        TypeAttr:$dstTy)> {
  let summary = "Convert vector<4xf32> to packed " # dstFormat
                # " with stochastic rounding (.rs) and satfinite";
  let description = [{
    Converts a vector<4xf32> to packed }] # dstFormat # [{ format using
    stochastic rounding (.rs) mode with SATFINITE saturation. Randomness is
    provided by the `rbits` parameter. The `dstTy` attribute specifies the
    target floating-point format. The `relu` attribute clamps negative results to 0.

    Note: These operations always use RS rounding mode and SATFINITE saturation mode.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // The verifier is hand-written per concrete op (declared via hasVerifier).
  let hasVerifier = 1;

  // The target float format is written in parentheses after the result type.
  let assemblyFormat = "$src `,` $rbits attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";

  // Each concrete op supplies its own getIntrinsicID() definition.
  let extraClassDeclaration = [{
    llvm::Intrinsic::ID getIntrinsicID();
  }];

  // Translation to LLVM IR: a single call to the intrinsic chosen by the op.
  string llvmBuilder = [{
    auto intId = op.getIntrinsicID();
    $dst = createIntrinsicCall(builder, intId, {$src, $rbits});
  }];
}
1990+
1991+
// Stochastic-rounding conversion: vector<4xf32> -> four packed F8 values in a
// vector<4xi8> (supports E4M3FN, E5M2).
def NVVM_ConvertF32x4ToF8x4Op
    : NVVM_ConvertF32x4ToFPx4OpBase<"f8x4", "convert.f32x4.to.f8x4",
                                    VectorOfLengthAndType<[4], [I8]>>;
1993+
1994+
// Stochastic-rounding conversion: vector<4xf32> -> four packed F6 values in a
// vector<4xi8> (supports E2M3FN, E3M2FN).
def NVVM_ConvertF32x4ToF6x4Op
    : NVVM_ConvertF32x4ToFPx4OpBase<"f6x4", "convert.f32x4.to.f6x4",
                                    VectorOfLengthAndType<[4], [I8]>>;
1996+
1997+
// Stochastic-rounding conversion: vector<4xf32> -> four packed F4 values in an
// i16 (supports E2M1FN).
def NVVM_ConvertF32x4ToF4x4Op
    : NVVM_ConvertF32x4ToFPx4OpBase<"f4x4", "convert.f32x4.to.f4x4", I16>;
1999+
19092000
//===----------------------------------------------------------------------===//
19102001
// NVVM MMA Ops
19112002
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,59 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
365365
return success();
366366
}
367367

368+
//===----------------------------------------------------------------------===//
369+
// Stochastic Rounding Conversion Ops
370+
//===----------------------------------------------------------------------===//
371+
372+
/// Accepts the op only when its `rnd` attribute is FPRoundingMode::RS;
/// any other rounding mode produces an op error.
LogicalResult ConvertF32x2ToF16x2Op::verify() {
  const bool usesStochasticRounding = getRnd() == FPRoundingMode::RS;
  if (usesStochasticRounding)
    return success();

  return emitOpError("Only RS rounding mode is supported "
                     "for conversions from f32x2 to f16x2.");
}
378+
379+
/// Accepts the op only when its `rnd` attribute is FPRoundingMode::RS;
/// any other rounding mode produces an op error.
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
  const bool usesStochasticRounding = getRnd() == FPRoundingMode::RS;
  if (usesStochasticRounding)
    return success();

  return emitOpError("Only RS rounding mode is supported "
                     "for conversions from f32x2 to bf16x2.");
}
385+
386+
/// Accepts the op only when `dstTy` is Float8E4M3FNType or Float8E5M2Type.
/// The diagnostic prints both accepted types, built from the op's context.
LogicalResult ConvertF32x4ToF8x4Op::verify() {
  mlir::Type requestedTy = getDstTy();
  if (llvm::isa<mlir::Float8E4M3FNType>(requestedTy) ||
      llvm::isa<mlir::Float8E5M2Type>(requestedTy))
    return success();

  // Failure path: the context is only needed to materialize the type names.
  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(context) << " and "
         << mlir::Float8E5M2Type::get(context)
         << " types are supported for conversions from f32x4 to f8x4.";
}
397+
398+
/// Accepts the op only when `dstTy` is Float6E2M3FNType or Float6E3M2FNType.
/// The diagnostic prints both accepted types, built from the op's context.
LogicalResult ConvertF32x4ToF6x4Op::verify() {
  mlir::Type requestedTy = getDstTy();
  if (llvm::isa<mlir::Float6E2M3FNType>(requestedTy) ||
      llvm::isa<mlir::Float6E3M2FNType>(requestedTy))
    return success();

  // Failure path: the context is only needed to materialize the type names.
  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float6E2M3FNType::get(context) << " and "
         << mlir::Float6E3M2FNType::get(context)
         << " types are supported for conversions from f32x4 to f6x4.";
}
409+
410+
/// Accepts the op only when `dstTy` is Float4E2M1FNType.
/// The diagnostic prints the accepted type, built from the op's context.
LogicalResult ConvertF32x4ToF4x4Op::verify() {
  if (llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return success();

  // Failure path: the context is only needed to materialize the type name.
  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float4E2M1FNType::get(context)
         << " type is supported for conversions from f32x4 to f4x4.";
}
420+
368421
LogicalResult BulkStoreOp::verify() {
369422
if (getInitVal() != 0)
370423
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2469,6 +2522,85 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
24692522
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
24702523
}()
24712524

2525+
/// Picks the nvvm_ff2f16x2_rs* intrinsic matching the op's modifiers:
/// `relu` adds the _relu suffix and SaturationMode::SATFINITE adds
/// _satfinite (relu first when both are present).
llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
  const bool relu = getRelu();
  const bool satFinite = getSat() == NVVM::SaturationMode::SATFINITE;

  if (relu)
    return satFinite ? llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite
                     : llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;

  return satFinite ? llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite
                   : llvm::Intrinsic::nvvm_ff2f16x2_rs;
}
2537+
2538+
/// Picks the nvvm_ff2bf16x2_rs* intrinsic matching the op's modifiers:
/// `relu` adds the _relu suffix and SaturationMode::SATFINITE adds
/// _satfinite (relu first when both are present).
llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
  const bool relu = getRelu();
  const bool satFinite = getSat() == NVVM::SaturationMode::SATFINITE;

  if (relu)
    return satFinite ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite
                     : llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;

  return satFinite ? llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite
                   : llvm::Intrinsic::nvvm_ff2bf16x2_rs;
}
2550+
2551+
/// Maps (`dstTy`, `relu`) to the matching nvvm_f32x4_to_*x4_rs[_relu]_satfinite
/// intrinsic for the two F8 formats handled here (E4M3FN and E5M2).
/// Any other `dstTy` is treated as unreachable.
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
  mlir::Type targetTy = getDstTy();
  const bool relu = getRelu();

  if (llvm::isa<mlir::Float8E4M3FNType>(targetTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;

  if (llvm::isa<mlir::Float8E5M2Type>(targetTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;

  llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
}
2569+
2570+
/// Maps (`dstTy`, `relu`) to the matching nvvm_f32x4_to_*x4_rs[_relu]_satfinite
/// intrinsic for the two F6 formats handled here (E2M3FN and E3M2FN).
/// Any other `dstTy` is treated as unreachable.
llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
  mlir::Type targetTy = getDstTy();
  const bool relu = getRelu();

  if (llvm::isa<mlir::Float6E2M3FNType>(targetTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;

  if (llvm::isa<mlir::Float6E3M2FNType>(targetTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;

  llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
}
2588+
2589+
/// Maps `relu` to the nvvm_f32x4_to_e2m1x4_rs[_relu]_satfinite intrinsic.
/// E2M1FN is the only F4 format handled; any other `dstTy` is treated as
/// unreachable.
llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
  const bool relu = getRelu();

  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");

  return relu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
              : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
}
2603+
24722604
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
24732605
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
24742606
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// Negative test: the module targets plain sm_100 (not sm_100a), so the
// stochastic-rounding conversion must be rejected.
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
  func.func @convert_rs() {
    %hi = llvm.mlir.constant(1.0 : f32) : f32
    %lo = llvm.mlir.constant(2.0 : f32) : f32
    %rand = llvm.mlir.constant(0x12345678 : i32) : i32
    // expected-error @below {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
    %packed = nvvm.convert.f32x2.to.f16x2 %hi, %lo, %rand : vector<2xf16>
    return
  }
}
14+
15+
// -----
16+
17+
// Negative test: f32x2 -> f16x2 only accepts the RS rounding mode, so an
// explicit `rn` must be diagnosed.
llvm.func @invalid_rnd_mode_f16x2(%hi : f32, %lo : f32, %rand : i32) -> vector<2xf16> {
  // expected-error @below {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
  %packed = nvvm.convert.f32x2.to.f16x2 %hi, %lo, %rand {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
  llvm.return %packed : vector<2xf16>
}
23+
24+
// -----
25+
26+
// Negative test: f32x2 -> bf16x2 only accepts the RS rounding mode, so an
// explicit `rz` must be diagnosed.
llvm.func @invalid_rnd_mode_bf16x2(%hi : f32, %lo : f32, %rand : i32) -> vector<2xbf16> {
  // expected-error @below {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
  %packed = nvvm.convert.f32x2.to.bf16x2 %hi, %lo, %rand {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
  llvm.return %packed : vector<2xbf16>
}
31+
32+
// -----
33+
34+
// Negative test: f8x4 conversions accept only f8E4M3FN and f8E5M2 as the
// destination format; f8E3M4 must be diagnosed.
llvm.func @invalid_dst_type_f8x4_e3m4(%input : vector<4xf32>, %rand : i32) -> vector<4xi8> {
  // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
  %packed = nvvm.convert.f32x4.to.f8x4 %input, %rand : vector<4xf32> -> vector<4xi8> (f8E3M4)
  llvm.return %packed : vector<4xi8>
}
40+
41+
// -----
42+
43+
// Negative test: f8E8M0FNU is an 8-bit float type but not one of the two
// formats accepted by the f8x4 conversion.
llvm.func @invalid_dst_type_f8x4_e8m0(%input : vector<4xf32>, %rand : i32) -> vector<4xi8> {
  // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
  %packed = nvvm.convert.f32x4.to.f8x4 %input, %rand : vector<4xf32> -> vector<4xi8> (f8E8M0FNU)
  llvm.return %packed : vector<4xi8>
}
48+
49+
// -----
50+
51+
// Negative test: f6x4 conversions accept only f6E2M3FN and f6E3M2FN as the
// destination format; an F8 format must be diagnosed.
llvm.func @invalid_dst_type_f6x4_f8(%input : vector<4xf32>, %rand : i32) -> vector<4xi8> {
  // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}}
  %packed = nvvm.convert.f32x4.to.f6x4 %input, %rand : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
  llvm.return %packed : vector<4xi8>
}
57+
58+
// -----
59+
60+
// Negative test: f4x4 conversions accept only f4E2M1FN as the destination
// format; an F6 format must be diagnosed.
llvm.func @invalid_dst_type_f4x4_f6(%input : vector<4xf32>, %rand : i32) -> i16 {
  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}}
  %packed = nvvm.convert.f32x4.to.f4x4 %input, %rand : vector<4xf32> -> i16 (f6E2M3FN)
  llvm.return %packed : i16
}
66+
67+
// -----
68+
69+
// Negative test: ops without a stochastic-rounding variant must reject the
// `rs` rounding mode. Here: float -> tf32.
llvm.func @convert_float_to_tf32_rs_not_supported(%input : f32) -> i32 {
  // expected-error@+1 {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}}
  %converted = nvvm.convert.float.to.tf32 %input {rnd = #nvvm.fp_rnd_mode<rs>}
  llvm.return %converted : i32
}
75+
76+
// -----
77+
78+
// Negative test: the existing f32x2 -> f8x2 conversion must reject the `rs`
// rounding mode.
llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%first : f32, %second : f32) {
  // expected-error@+1 {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
  %packed = nvvm.convert.f32x2.to.f8x2 %first, %second {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
  llvm.return
}
83+
84+
// -----
85+
86+
// Negative test: the existing bf16x2 -> f8x2 conversion must reject the `rs`
// rounding mode.
llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%input : vector<2xbf16>) {
  // expected-error@+1 {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
  %packed = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
  llvm.return
}

0 commit comments

Comments
 (0)