Skip to content

Commit 601f796

Browse files
authored
[MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (#169005)
This change adds the `RN` and `RZ` rounding modes to the `convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. Tests are added in `convert_fp16x2.mlir` and `invalid_convert_fp16x2.mlir`. Tests with these Ops in `convert_stochastic_rounding.mlir` and `invalid-convert-stochastic-rounding.mlir` have been removed or modified. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent e7dec23 commit 601f796

File tree

6 files changed

+294
-138
lines changed

6 files changed

+294
-138
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,45 +1998,57 @@ def NVVM_ConvertF4x2ToF16x2Op :
19981998

19991999
// Base class for conversions from F32x2 to FPx2 formats
20002000
// (F16x2, BF16x2)
2001-
// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
2002-
// as currently only support .rs rounding mode
20032001
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
2004-
NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
2002+
NVVM_Op<mnemonic, [Pure]>,
20052003
Results<(outs dstType:$dst)>,
2006-
Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
2007-
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
2004+
Arguments<(ins F32:$src_hi, F32:$src_lo,
2005+
Optional<I32>:$random_bits,
2006+
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
20082007
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
20092008
DefaultValuedAttr<BoolAttr, "false">:$relu)> {
2010-
let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
2009+
let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
20112010
let description = [{
2012-
Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
2013-
rounding (.rs) mode with randomness provided by the `rbits` parameter. The
2014-
`relu` attribute clamps negative results to 0. The `sat` attribute determines
2015-
saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
2016-
`a` and `b` in the PTX ISA, respectively.
2011+
Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with
2012+
the specified rounding mode. The `src_hi` and `src_lo` parameters
2013+
correspond to operands `a` and `b` in the PTX ISA, respectively.
2014+
2015+
The `random_bits` parameter is required for stochastic rounding and
2016+
provides the [random bits](}] #
2017+
!if(!eq(dstFormat, "F16x2"),
2018+
"https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16",
2019+
"https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-bf16") #
2020+
[{) to be used for the conversion.
2021+
2022+
The `relu` attribute clamps negative results to 0.
2023+
2024+
The `sat` attribute determines saturation behavior.
20172025

20182026
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
20192027
}];
20202028

2021-
let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
2029+
let assemblyFormat = "$src_hi `,` $src_lo (`,` $random_bits^)? attr-dict `:` type($dst)";
20222030

20232031
let hasVerifier = 1;
20242032

20252033
let extraClassDeclaration = [{
2026-
llvm::Intrinsic::ID getIntrinsicID();
2034+
static NVVM::IDArgPair
2035+
getIntrinsicIDAndArgs(
2036+
NVVM::ConvertF32x2To}] # dstFormat # [{Op &op,
2037+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
20272038
}];
20282039

20292040
string llvmBuilder = [{
2030-
auto intId = op.getIntrinsicID();
2031-
$dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
2041+
auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat #
2042+
[{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
2043+
$dst = createIntrinsicCall(builder, intId, args);
20322044
}];
2033-
}
2045+
}
20342046

2035-
// F32x2 -> F16x2 with stochastic rounding
2036-
def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
2047+
// F32x2 -> F16x2
2048+
def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
20372049

2038-
// F32x2 -> BF16x2 with stochastic rounding
2039-
def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
2050+
// F32x2 -> BF16x2
2051+
def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
20402052

20412053
// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
20422054
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 123 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -452,18 +452,42 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
452452
// Stochastic Rounding Conversion Ops
453453
//===----------------------------------------------------------------------===//
454454

455-
LogicalResult ConvertF32x2ToF16x2Op::verify() {
456-
if (getRnd() != FPRoundingMode::RS)
457-
return emitOpError("Only RS rounding mode is supported for "
458-
"conversions from f32x2 to f16x2.");
455+
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
456+
FPRoundingMode rnd,
457+
bool hasRandomBits,
458+
Operation *op) {
459+
static constexpr FPRoundingMode validRndModes[] = {
460+
FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
461+
462+
if (!llvm::is_contained(validRndModes, rnd)) {
463+
return op->emitOpError(
464+
"Only RN, RZ, and RS rounding modes are supported for "
465+
"conversions from f32x2 to ")
466+
<< dstType << ".";
467+
}
468+
469+
if (rnd == FPRoundingMode::RS) {
470+
if (!hasRandomBits) {
471+
return op->emitOpError("random_bits is required for RS rounding mode.");
472+
}
473+
} else {
474+
if (hasRandomBits) {
475+
return op->emitOpError(
476+
"random_bits not supported for RN and RZ rounding modes.");
477+
}
478+
}
479+
459480
return success();
460481
}
461482

483+
LogicalResult ConvertF32x2ToF16x2Op::verify() {
484+
return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
485+
getRandomBits() ? true : false, *this);
486+
}
487+
462488
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
463-
if (getRnd() != FPRoundingMode::RS)
464-
return emitOpError("Only RS rounding mode is supported for "
465-
"conversions from f32x2 to bf16x2.");
466-
return success();
489+
return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
490+
getRandomBits() ? true : false, *this);
467491
}
468492

469493
LogicalResult ConvertF32x4ToF8x4Op::verify() {
@@ -2921,30 +2945,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
29212945
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
29222946
}()
29232947

2924-
llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
2925-
bool hasRelu = getRelu();
2926-
bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2948+
NVVM::IDArgPair
2949+
ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
2950+
LLVM::ModuleTranslation &mt,
2951+
llvm::IRBuilderBase &builder) {
2952+
static constexpr llvm::Intrinsic::ID rndRNIds[] = {
2953+
llvm::Intrinsic::nvvm_ff2f16x2_rn,
2954+
llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
2955+
llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
2956+
llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
2957+
};
2958+
static constexpr llvm::Intrinsic::ID rndRZIds[] = {
2959+
llvm::Intrinsic::nvvm_ff2f16x2_rz,
2960+
llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
2961+
llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
2962+
llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
2963+
};
2964+
static constexpr llvm::Intrinsic::ID rndRSIds[] = {
2965+
llvm::Intrinsic::nvvm_ff2f16x2_rs,
2966+
llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
2967+
llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
2968+
llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
2969+
};
29272970

2928-
if (hasRelu && hasSatFinite)
2929-
return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
2930-
if (hasRelu)
2931-
return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
2932-
if (hasSatFinite)
2933-
return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
2934-
return llvm::Intrinsic::nvvm_ff2f16x2_rs;
2971+
unsigned hasRelu = op.getRelu() ? 1 : 0;
2972+
unsigned hasSatFinite =
2973+
(op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
2974+
// idx: bit-0 - relu
2975+
// bit-1 - satfinite
2976+
unsigned idx = (hasSatFinite << 1) | hasRelu;
2977+
2978+
llvm::SmallVector<llvm::Value *> args;
2979+
args.push_back(mt.lookupValue(op.getSrcHi()));
2980+
args.push_back(mt.lookupValue(op.getSrcLo()));
2981+
if (op.getRandomBits())
2982+
args.push_back(mt.lookupValue(op.getRandomBits()));
2983+
2984+
switch (op.getRnd()) {
2985+
case FPRoundingMode::RN:
2986+
return {rndRNIds[idx], std::move(args)};
2987+
case FPRoundingMode::RZ:
2988+
return {rndRZIds[idx], std::move(args)};
2989+
case FPRoundingMode::RS:
2990+
return {rndRSIds[idx], std::move(args)};
2991+
default:
2992+
llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
2993+
}
29352994
}
29362995

2937-
llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
2938-
bool hasRelu = getRelu();
2939-
bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2940-
2941-
if (hasRelu && hasSatFinite)
2942-
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
2943-
if (hasRelu)
2944-
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
2945-
if (hasSatFinite)
2946-
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
2947-
return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
2996+
NVVM::IDArgPair
2997+
ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
2998+
LLVM::ModuleTranslation &mt,
2999+
llvm::IRBuilderBase &builder) {
3000+
static constexpr llvm::Intrinsic::ID rndRNIds[] = {
3001+
llvm::Intrinsic::nvvm_ff2bf16x2_rn,
3002+
llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
3003+
llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
3004+
llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
3005+
};
3006+
static constexpr llvm::Intrinsic::ID rndRZIds[] = {
3007+
llvm::Intrinsic::nvvm_ff2bf16x2_rz,
3008+
llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
3009+
llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
3010+
llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
3011+
};
3012+
static constexpr llvm::Intrinsic::ID rndRSIds[] = {
3013+
llvm::Intrinsic::nvvm_ff2bf16x2_rs,
3014+
llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
3015+
llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
3016+
llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
3017+
};
3018+
3019+
unsigned hasRelu = op.getRelu() ? 1 : 0;
3020+
unsigned hasSatFinite =
3021+
(op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
3022+
// idx: bit-0 - relu
3023+
// bit-1 - satfinite
3024+
unsigned idx = (hasSatFinite << 1) | hasRelu;
3025+
3026+
llvm::SmallVector<llvm::Value *> args;
3027+
args.push_back(mt.lookupValue(op.getSrcHi()));
3028+
args.push_back(mt.lookupValue(op.getSrcLo()));
3029+
if (op.getRandomBits())
3030+
args.push_back(mt.lookupValue(op.getRandomBits()));
3031+
3032+
switch (op.getRnd()) {
3033+
case FPRoundingMode::RN:
3034+
return {rndRNIds[idx], std::move(args)};
3035+
case FPRoundingMode::RZ:
3036+
return {rndRZIds[idx], std::move(args)};
3037+
case FPRoundingMode::RS:
3038+
return {rndRSIds[idx], std::move(args)};
3039+
default:
3040+
llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
3041+
}
29483042
}
29493043

29503044
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {

mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,15 @@
22

33
// Test invalid target architecture (sm_100 instead of sm_100a)
44
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
5-
func.func @convert_rs() {
6-
%f1 = llvm.mlir.constant(1.0 : f32) : f32
7-
%f2 = llvm.mlir.constant(2.0 : f32) : f32
8-
%rbits = llvm.mlir.constant(0x12345678 : i32) : i32
9-
// expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
10-
%res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
5+
func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
6+
// expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
7+
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
118
return
129
}
1310
}
1411

1512
// -----
1613

17-
// Test that operations require stochastic rounding mode
18-
llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
19-
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
20-
%res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
21-
llvm.return %res : vector<2xf16>
22-
}
23-
24-
// -----
25-
26-
llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
27-
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
28-
%res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
29-
llvm.return %res : vector<2xbf16>
30-
}
31-
32-
// -----
33-
3414
// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
3515
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
3616
// expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
4+
llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
5+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
6+
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
7+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8+
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
9+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
10+
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
11+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
12+
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
13+
14+
llvm.return
15+
}
16+
17+
// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
18+
llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
19+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
20+
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
21+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
22+
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
23+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
24+
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
25+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
26+
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
27+
28+
llvm.return
29+
}
30+
31+
// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
32+
llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
33+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
34+
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
35+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
36+
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
37+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
38+
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
39+
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
40+
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
41+
42+
llvm.return
43+
}
44+
45+
// -----
46+
47+
// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
48+
llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
49+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
50+
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
51+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
52+
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
53+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
54+
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
55+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
56+
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
57+
58+
llvm.return
59+
}
60+
61+
// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
62+
llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
63+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
64+
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
65+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
66+
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
67+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
68+
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
69+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
70+
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
71+
72+
llvm.return
73+
}
74+
75+
// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
76+
llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
77+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
78+
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
79+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
80+
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
81+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
82+
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
83+
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
84+
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
85+
86+
llvm.return
87+
}

0 commit comments

Comments
 (0)