diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 04042903e343e..bf3131932a56b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -970,6 +970,81 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the floating point rounding modes supported by PTX
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN   : I32EnumAttrCase<"RN", 1, "rn">;
+def FPRoundingModeRM   : I32EnumAttrCase<"RM", 2, "rm">;
+def FPRoundingModeRP   : I32EnumAttrCase<"RP", 3, "rp">;
+def FPRoundingModeRZ   : I32EnumAttrCase<"RZ", 4, "rz">;
+def FPRoundingModeRNA  : I32EnumAttrCase<"RNA", 5, "rna">;
+
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+   FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Attributes for the saturation modes used by the conversion Op below
+def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+  [SaturationModeNone, SaturationModeFinite]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
+  let summary = "Convert the given float input to TF32";
+  let description = [{
+    This Op converts the given f32 input to tf32.
+    The result `res` is represented as an i32 type.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction. The `rnd` and `sat` attributes specify the
+    rounding and saturation modes respectively.
+    Only the `rn`, `rz` and `rna` rounding modes are accepted by the
+    verifier; `relu` is rejected with `rna`, and a saturation mode is
+    rejected with `rn`/`rz`.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs I32:$res);
+  let arguments = (ins
+    F32:$src,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+
+  let assemblyFormat = "$src attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
+                                              NVVM::SaturationMode,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
+    $res = createIntrinsicCall(builder, intId, {$src});
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM MMA Ops
+//===----------------------------------------------------------------------===//
 /// Helpers to instantiate different version of wmma intrinsics.
 /// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
 /// combinations of the intrinsics.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d8fde3e765ac4..ccb5ad05f0bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
                                          getLoc());
 }
 
+/// Checks that the rounding mode, saturation mode and relu flag of this Op
+/// form a supported combination: `rna` without relu, or `rn`/`rz` without a
+/// saturation mode. Any other rounding mode is rejected.
+LogicalResult CvtFloatToTF32Op::verify() {
+  const NVVM::FPRoundingMode rounding = getRnd();
+  const bool isRna = rounding == NVVM::FPRoundingMode::RNA;
+  const bool isRnOrRz = rounding == NVVM::FPRoundingMode::RN ||
+                        rounding == NVVM::FPRoundingMode::RZ;
+
+  if (!isRna && !isRnOrRz)
+    return emitError(
+        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
+  if (isRna && getRelu())
+    return emitError("Relu not supported with rna rounding mode.");
+  if (isRnOrRz && getSat() != NVVM::SaturationMode::NONE)
+    return emitError(
+        "Saturation mode not supported with rn/rz rounding modes.");
+  return success();
+}
+
 // Given the element type of an operand and whether or not it is an accumulator,
 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
 // operand's element type.
@@ -1163,6 +1183,29 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
   llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
 }
 
+/// Picks the llvm.nvvm.f2tf32.* intrinsic for the given modifiers. `hasRelu`
+/// is only consulted for rn/rz and `sat` only for rna.
+llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
+                                                     NVVM::SaturationMode sat,
+                                                     bool hasRelu) {
+  if (rnd == NVVM::FPRoundingMode::RN) {
+    if (hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
+    return llvm::Intrinsic::nvvm_f2tf32_rn;
+  }
+  if (rnd == NVVM::FPRoundingMode::RZ) {
+    if (hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
+    return llvm::Intrinsic::nvvm_f2tf32_rz;
+  }
+  if (rnd == NVVM::FPRoundingMode::RNA) {
+    if (sat == NVVM::SaturationMode::SATFINITE)
+      return llvm::Intrinsic::nvvm_f2tf32_rna_satfinite;
+    return llvm::Intrinsic::nvvm_f2tf32_rna;
+  }
+  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
+}
+
 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
 /// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
new file mode 100644
index 0000000000000..90a232e4baac6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
+
+// Each function translates nvvm.cvt.float.to.tf32 with one rounding mode /
+// modifier combination and expects the matching llvm.nvvm.f2tf32.* call.
+
+// CHECK-LABEL: @convert_float_to_tf32_rna
+llvm.func @convert_float_to_tf32_rna(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rna_sf
+llvm.func @convert_float_to_tf32_rna_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn
+llvm.func @convert_float_to_tf32_rn(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu
+llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz
+llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu
+llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
+  llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 44c7126255dc4..cb08064590bc3 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
   nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode}: !llvm.ptr, !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+
+llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
+  // expected-error @below {{Relu not supported with rna rounding mode.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, relu=true}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
+  // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
+  %res = nvvm.cvt.float.to.tf32 %src
+  llvm.return %res : i32
+}