[clang][NVPTX] Add support for mixed-precision FP arithmetic #168359
base: main
Conversation
This change adds NVVM intrinsics and clang builtins for mixed-precision FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and verified through `ptxas-13.0`.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
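As a rough usage sketch (not part of the patch): the snippet below shows how a few of the new builtins might be called from CUDA device code targeting `sm_100` with PTX ISA 8.6. The function name, constant values, and the output-pointer convention are illustrative assumptions; the builtin names and signatures come from the `BuiltinsNVPTX.td` changes in the diff.

```c
// Minimal sketch (assumed, not part of the patch): calling the new
// mixed-precision builtins from CUDA device code for sm_100 / PTX 8.6+.
__device__ void mixed_precision_example(float f, float *out) {
  __fp16 h = (__fp16)0.25f; // half-precision operand
  __bf16 b = (__bf16)0.5f;  // bfloat16 operand

  // f16 + f32 -> f32, round-to-nearest-even.
  out[0] = __nvvm_add_mixed_rn_f16_f32(h, f);
  // bf16 - f32 -> f32, round-toward-zero, result saturated to [0.0, 1.0].
  out[1] = __nvvm_sub_mixed_rz_sat_bf16_f32(b, f);
  // f16 * f16 + f32 -> f32 fused multiply-add, round-to-nearest-even.
  out[2] = __nvvm_fma_mixed_rn_f16_f32(h, h, f);
}
```

Each call combines the narrow-to-`float` conversion and the arithmetic into a single mixed-precision PTX instruction.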
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-clang

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds NVVM intrinsics and clang builtins for mixed-precision FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and verified through `ptxas-13.0`.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions

Patch is 37.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168359.diff

6 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index d923d2a90e908..47ba12bef058c 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
+def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+
+def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+
// Rcp
def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
+// Sub
+
+def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 8a1cab3417d98..6f57620f0fb00 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
+ const CallExpr *E,
+ CodeGenFunction &CGF) {
+ SmallVector<llvm::Value *, 3> Args;
+ for (unsigned i = 0; i < E->getNumArgs(); ++i) {
+ Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
+ }
+ return CGF.Builder.CreateCall(
+ CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
+ case NVPTX::BI__nvvm_add_mixed_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
+ E, *this);
default:
return nullptr;
}
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index e3be262622844..1753b4c7767e9 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1466,3 +1466,136 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+
+__device__ void nvvm_add_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f);
+#endif
+}
+
+__device__ void nvvm_sub_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ _...
[truncated]
Ping @AlexMaclean for review.
def int_nvvm_fma_ # rnd # _d : NVVMBuiltin,
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in {
It looks like .sat is supported for f16 as well. Do we want to make this an overloaded intrinsic so that both types are supported?
It looks like only the `rn` rounding mode is supported for the half-precision instructions, so only the variants with `rn` can be overloaded. So I was thinking we could keep them as separate intrinsics for now.
Edit: Same for add and sub.
foreach sat = ["", "_sat"] in {
foreach type = ["f16", "bf16"] in {
def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
For the mixed FP instructions, should we also add folds for the generic fadd/fsub/fma instructions in LLVM?
That makes sense. I've added patterns to fold `fadd`, `fsub`, and `llvm.fma.f32` to the mixed-precision instructions when `ftz` isn't present. Please take a look, thanks!
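As a rough illustration (hypothetical, not taken from the patch) of the kind of source pattern those folds target: widening an `__fp16` or `__bf16` operand to `float` before the arithmetic shows up in IR as an `fpext` feeding a generic `fadd`/`fsub`/`llvm.fma.f32`, which can now be selected to a single mixed-precision instruction on `sm_100` when `ftz` is not requested.

```c
// Hypothetical example (assumed, not from the patch): the half->float
// conversion plus the add appears in LLVM IR as an fpext feeding an fadd,
// i.e. the generic pattern the new folds can turn into one mixed-precision
// add when compiling for sm_100 without ftz.
__device__ float widen_then_add(const __fp16 *a, float b) {
  return (float)*a + b; // load half, fpext to float, then fadd float
}
```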
This change adds support for mixed-precision floating-point arithmetic for `f16` and `bf16`, where `fadd`, `fsub`, and `fma` patterns involving these types are lowered to the corresponding mixed-precision instructions, which combine the conversion and the operation into one instruction from `sm_100` onwards.

This also adds the following intrinsics to complete support for all variants of the floating-point `add`/`sub`/`fma` operations, in order to support the corresponding mixed-precision instructions:

- `llvm.nvvm.add.(rn/rz/rm/rp){.ftz}.sat.f`
- `llvm.nvvm.fma.(rn/rz/rm/rp){.ftz}.sat.f`
- `llvm.nvvm.sub.*`

Tests are added in `fp-arith-sat.ll`, `fp-sub-intrins.ll`, and `builtins-nvptx.c` for the newly added intrinsics and builtins, and in `mixed-precision-fp.ll` for the mixed-precision instructions.

PTX spec reference for mixed-precision instructions: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions