Skip to content

Commit b435f83

Browse files
committed
[NVPTX] Add convert float to tf32 intrinsics
This patch adds an intrinsic to convert float to tf32. * This intrinsic uses flags for rounding and saturation modes as well as relu. The backend looks through these flags and lowers to the appropriate instruction. * Docs are updated to describe the usage of the flag arguments. * Lit tests are added for all the combinations. Note: We already have an intrinsic 'llvm.nvvm.f2tf32.rna' which caters only to one variant of the PTX instruction. Once this change lands, I will submit a follow-up PR to auto-upgrade it to use the generic variant. PTX Spec link: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt Signed-off-by: Durgadoss R <[email protected]>
1 parent 1849244 commit b435f83

File tree

12 files changed

+299
-0
lines changed

12 files changed

+299
-0
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,66 @@ to left-shift the found bit into the most-significant bit position, otherwise
462462
the result is the shift amount needed to right-shift the found bit into the
463463
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
464464

465+
Conversion Intrinsics (for cvt.* PTX instructions)
466+
--------------------------------------------------
467+
468+
'``llvm.nvvm.cvt.float.to.tf32``'
469+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
470+
471+
Syntax:
472+
"""""""
473+
474+
.. code-block:: llvm
475+
476+
declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 %flag_fp_rnd_mode, i8 %flag_sat_mode, i1 %flag_relu)
477+
478+
Overview:
479+
"""""""""
480+
481+
The '``@llvm.nvvm.cvt.float.to.tf32``' intrinsic lowers to
482+
the ``cvt.*.tf32.f32`` set of PTX instructions.
483+
484+
* The first argument is the input float to be converted to TF32.
485+
This is followed by three flag arguments encoding the rounding mode,
486+
saturation mode, and the relu modifier respectively.
487+
488+
* The second argument (denoted by ``i8 %flag_fp_rnd_mode``) denotes
489+
the floating-point rounding modes supported for this instruction.
490+
This must be a compile-time constant and the encoding is as below:
491+
492+
========== ==============
493+
Enum Value Rounding Mode
494+
========== ==============
495+
``0`` NONE
496+
``1`` ROUND_RZ
497+
``2`` ROUND_RN
498+
``3`` ROUND_RP
499+
``4`` ROUND_RM
500+
``5`` ROUND_RNA
501+
========== ==============
502+
503+
The valid rounding modes are ``RNA, RN and RZ``.
504+
505+
* The third argument (denoted by ``i8 %flag_sat_mode``) denotes the
506+
saturation modifier for this intrinsic. As of now, it can either
507+
be None or Satfinite, according to the enumeration below:
508+
509+
========== ================
510+
Enum Value Saturation Mode
511+
========== ================
512+
``0`` NONE
513+
``1`` SATFINITE
514+
========== ================
515+
516+
* The last argument (denoted by ``i1 %flag_relu``) when set, generates
517+
the ``.relu`` variant of the instruction.
518+
519+
* Invalid values for the compile-time flag arguments may lead
520+
to error(s) during Codegen.
521+
522+
For more information, refer PTX ISA
523+
`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
524+
465525
TMA family of Intrinsics
466526
------------------------
467527

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,15 @@ let TargetPrefix = "nvvm" in {
14661466
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
14671467
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
14681468

1469+
// Convert Float to TF32
1470+
def int_nvvm_cvt_float_to_tf32 : Intrinsic<[llvm_i32_ty],
1471+
[llvm_float_ty, // Input float
1472+
llvm_i8_ty, // Flag for Rounding Modes
1473+
llvm_i8_ty, // Flag for Saturation Modes
1474+
llvm_i1_ty], // Flag for relu
1475+
[IntrNoMem, IntrNoCallback,
1476+
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
1477+
14691478
// FNS
14701479

14711480
def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,

llvm/include/llvm/IR/NVVMIntrinsicFlags.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ enum class TMAReductionOp : uint8_t {
3434
XOR = 7,
3535
};
3636

37+
// Rounding Modes for floating point types
38+
enum class FPRoundingMode : uint8_t {
39+
NONE = 0,
40+
ROUND_RZ = 1, // roundTowardZero
41+
ROUND_RN = 2, // roundToNearest-TiesToEven
42+
ROUND_RP = 3, // roundTowardPositiveInf
43+
ROUND_RM = 4, // roundTowardNegativeInf
44+
ROUND_RNA = 5, // roundToNearest-TiesAwayFromZero
45+
};
46+
47+
// Saturation Modes
48+
enum class SaturationMode : uint8_t {
49+
NONE = 0,
50+
SATFINITE = 1,
51+
};
52+
3753
} // namespace nvvm
3854
} // namespace llvm
3955
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,56 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
453453
llvm_unreachable(
454454
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
455455
}
456+
457+
void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
458+
raw_ostream &O,
459+
const char *Modifier) {
460+
const MCOperand &MO = MI->getOperand(OpNum);
461+
using Mode = nvvm::FPRoundingMode;
462+
463+
switch (static_cast<Mode>(MO.getImm())) {
464+
case Mode::NONE:
465+
O << "";
466+
return;
467+
case Mode::ROUND_RN:
468+
O << ".rn";
469+
return;
470+
case Mode::ROUND_RNA:
471+
O << ".rna";
472+
return;
473+
case Mode::ROUND_RZ:
474+
O << ".rz";
475+
return;
476+
case Mode::ROUND_RP:
477+
O << ".rp";
478+
return;
479+
case Mode::ROUND_RM:
480+
O << ".rm";
481+
return;
482+
}
483+
llvm_unreachable("Invalid mode in printFPRoundingMode");
484+
}
485+
486+
void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
487+
raw_ostream &O,
488+
const char *Modifier) {
489+
const MCOperand &MO = MI->getOperand(OpNum);
490+
using Mode = nvvm::SaturationMode;
491+
492+
switch (static_cast<Mode>(MO.getImm())) {
493+
case Mode::NONE:
494+
O << "";
495+
return;
496+
case Mode::SATFINITE:
497+
O << ".satfinite";
498+
return;
499+
}
500+
llvm_unreachable("Invalid mode in printSaturationMode");
501+
}
502+
503+
void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
504+
raw_ostream &O, const char *Modifier) {
505+
const MCOperand &MO = MI->getOperand(OpNum);
506+
if (MO.getImm())
507+
O << ".relu";
508+
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
5656
const char *Modifier = nullptr);
5757
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
5858
const char *Modifier = nullptr);
59+
void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
60+
const char *Modifier = nullptr);
61+
void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
62+
const char *Modifier = nullptr);
63+
void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
64+
const char *Modifier = nullptr);
5965
};
6066

6167
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,53 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
728728
case Intrinsic::nvvm_texsurf_handle_internal:
729729
SelectTexSurfHandle(N);
730730
return true;
731+
case Intrinsic::nvvm_cvt_float_to_tf32:
732+
SelectCvtFloatToTF32(N);
733+
return true;
734+
}
735+
}
736+
737+
void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
738+
// 0 - IID
739+
// 1 - Input Float
740+
// 2 - Rounding Mode
741+
// 3 - Saturation Mode
742+
// 4 - Relu Flag
743+
uint64_t Rnd = N->getConstantOperandVal(2);
744+
uint64_t Sat = N->getConstantOperandVal(3);
745+
bool IsRelu = N->getConstantOperandVal(4) == 1;
746+
747+
if (!Subtarget->hasTF32Math())
748+
report_fatal_error("TF32 destination format requires at least sm80");
749+
750+
using SatMode = nvvm::SaturationMode;
751+
bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
752+
if (IsSatFinite && Subtarget->getPTXVersion() < 81)
753+
report_fatal_error("satfinite modifier requires PTX version 8.1 or higher");
754+
755+
using RndMode = nvvm::FPRoundingMode;
756+
switch (static_cast<RndMode>(Rnd)) {
757+
case RndMode::ROUND_RNA:
758+
if (IsRelu)
759+
report_fatal_error("relu not supported with rna rounding mode");
760+
break;
761+
case RndMode::ROUND_RN:
762+
case RndMode::ROUND_RZ: {
763+
if (Subtarget->getSmVersion() < 90)
764+
report_fatal_error("rn/rz rounding modes require at least sm90");
765+
if (IsSatFinite)
766+
report_fatal_error("satfinite not supported with rn/rz rounding modes");
767+
break;
768+
}
769+
default:
770+
report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
731771
}
772+
773+
SDLoc DL(N);
774+
SDValue Ops[] = {N->getOperand(1), getI32Imm(Rnd, DL), getI32Imm(Sat, DL),
775+
getI32Imm(IsRelu, DL)};
776+
ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_float_to_tf32, DL,
777+
N->getVTList(), Ops));
732778
}
733779

734780
void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7373
bool tryIntrinsicChain(SDNode *N);
7474
bool tryIntrinsicVoid(SDNode *N);
7575
void SelectTexSurfHandle(SDNode *N);
76+
void SelectCvtFloatToTF32(SDNode *N);
7677
bool tryLoad(SDNode *N);
7778
bool tryLoadVector(SDNode *N);
7879
bool tryLDGLDU(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
18021802
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
18031803
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
18041804

1805+
def FPRoundingMode : Operand<i32> {
1806+
let PrintMethod = "printFPRoundingMode";
1807+
}
1808+
1809+
def SatMode : Operand<i32> {
1810+
let PrintMethod = "printSaturationMode";
1811+
}
1812+
1813+
def ReluFlag : Operand<i32> {
1814+
let PrintMethod = "printReluModifier";
1815+
}
1816+
1817+
def cvt_float_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
1818+
(ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
1819+
"cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
1820+
18051821
//
18061822
// FNS
18071823
//

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
8383
bool hasFP16Math() const { return SmVersion >= 53; }
8484
bool hasBF16Math() const { return SmVersion >= 80; }
8585
bool allowFP16Math() const;
86+
bool hasTF32Math() const { return SmVersion >= 80 && PTXVersion >= 70; }
8687
bool hasMaskOperator() const { return PTXVersion >= 71; }
8788
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
8889
// Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,

llvm/test/CodeGen/NVPTX/convert-sm80.ll

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,20 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
261261
%v1 = insertelement <2 x half> %v0, half %hih, i64 1
262262
ret <2 x half> %v1
263263
}
264+
265+
declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
266+
267+
define i32 @cvt_rna_tf32_f32_flags(float %f1) {
268+
; CHECK-LABEL: cvt_rna_tf32_f32_flags(
269+
; CHECK: {
270+
; CHECK-NEXT: .reg .b32 %r<2>;
271+
; CHECK-NEXT: .reg .f32 %f<2>;
272+
; CHECK-EMPTY:
273+
; CHECK-NEXT: // %bb.0:
274+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
275+
; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1;
276+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
277+
; CHECK-NEXT: ret;
278+
%val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 5, i8 0, i1 0)
279+
ret i32 %val
280+
}

0 commit comments

Comments
 (0)