Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;

def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;

def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;

// Rcp

def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
Expand Down Expand Up @@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;

def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;

def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;

// Sub

def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;

def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;

// Convert

def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
Expand Down
123 changes: 123 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}

static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
const CallExpr *E,
CodeGenFunction &CGF) {
SmallVector<llvm::Value *, 3> Args;
for (unsigned i = 0; i < E->getNumArgs(); ++i) {
Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
}
return CGF.Builder.CreateCall(
CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
}

} // namespace

Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
Expand Down Expand Up @@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
case NVPTX::BI__nvvm_add_mixed_f16_f32:
case NVPTX::BI__nvvm_add_mixed_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
*this);
case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
E, *this);
case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
E, *this);
case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
E, *this);
case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
E, *this);
case NVPTX::BI__nvvm_sub_mixed_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
*this);
case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
E, *this);
case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
E, *this);
case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
E, *this);
case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
E, *this);
case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
*this);
case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
*this);
case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
*this);
case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
*this);
case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
E, *this);
case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
E, *this);
case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
E, *this);
case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
E, *this);
default:
return nullptr;
}
Expand Down
Loading
Loading