Skip to content

Commit 5bd9195

Browse files
committed
address comments
1 parent 52c1252 commit 5bd9195

File tree

2 files changed

+40
-60
lines changed

2 files changed

+40
-60
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,19 +2845,6 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
28452845
{TryCancelResponse0, TryCancelResponse1});
28462846
}
28472847

2848-
bool isCvtRSReluIntrinsic(Intrinsic::ID ID) {
2849-
switch (ID) {
2850-
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2851-
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2852-
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2853-
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2854-
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2855-
return true;
2856-
default:
2857-
return false;
2858-
}
2859-
}
2860-
28612848
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
28622849
SDNode *N = Op.getNode();
28632850
SDLoc DL(N);
@@ -2867,43 +2854,44 @@ static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
28672854
unsigned IntrinsicID = N->getConstantOperandVal(0);
28682855

28692856
uint32_t CvtModeFlag = NVPTX::PTXCvtMode::CvtMode::RS;
2870-
if (isCvtRSReluIntrinsic(IntrinsicID))
2871-
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
2872-
2873-
SDValue Float1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2874-
DAG.getIntPtrConstant(0, DL));
2875-
SDValue Float2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2876-
DAG.getIntPtrConstant(1, DL));
2877-
SDValue Float3 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2878-
DAG.getIntPtrConstant(2, DL));
2879-
SDValue Float4 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2880-
DAG.getIntPtrConstant(3, DL));
2857+
2858+
// Extract the 4 float elements from the vector
2859+
SmallVector<SDValue, 6> Ops;
2860+
for (unsigned i = 0; i < 4; ++i) {
2861+
Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2862+
DAG.getIntPtrConstant(i, DL)));
2863+
}
28812864

28822865
auto OpSignature =
28832866
[&]() -> std::pair<NVPTXISD::NodeType, MVT::SimpleValueType> {
28842867
switch (IntrinsicID) {
28852868
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2869+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
28862870
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
28872871
return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8};
28882872
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2873+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
28892874
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
28902875
return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8};
28912876
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2877+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
28922878
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
28932879
return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8};
28942880
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2881+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
28952882
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
28962883
return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8};
28972884
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2885+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
28982886
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
28992887
return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16};
29002888
default:
29012889
llvm_unreachable("unsupported/unhandled intrinsic");
29022890
}
29032891
}();
29042892

2905-
SDValue Ops[] = {Float1, Float2, Float3,
2906-
Float4, RBits, DAG.getConstant(CvtModeFlag, DL, MVT::i32)};
2893+
Ops.push_back(RBits);
2894+
Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32));
29072895

29082896
return DAG.getNode(OpSignature.first, DL, OpSignature.second, Ops);
29092897
}

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,57 +1950,49 @@ let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
19501950
(CVT_bf16x2_ue8m0x2 $a)>;
19511951
}
19521952

1953-
def SDT_CVT_F32X4_TO_FP8X4_RS :
1953+
def SDT_CVT_F32X4_TO_FPX4_RS_VEC :
19541954
SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
19551955
SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
19561956

1957-
def SDT_CVT_F32X4_TO_FP6X4_RS :
1958-
SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
1959-
SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
1960-
1961-
def SDT_CVT_F32X4_TO_FP4X4_RS :
1957+
def SDT_CVT_F32X4_TO_FPX4_RS_INT :
19621958
SDTypeProfile<1, 6, [SDTCisInt<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
19631959
SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
19641960

19651961
class CVT_F32X4_TO_FPX4_RS_SF_NODE<string FPName, SDTypeProfile SDT> :
19661962
SDNode<"NVPTXISD::CVT_" # FPName # "X4_F32X4_RS_SF", SDT, []>;
1963+
1964+
multiclass CVT_F32X4_TO_FPX4_RS_SF_VEC<string FPName, VTVec RetTy> {
1965+
def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
1966+
SDT_CVT_F32X4_TO_FPX4_RS_VEC>
1967+
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1968+
(!cast<NVPTXInst>(CVT_ # FPName # "x4_f32x4_rs_sf")
1969+
$f1, $f2, $f3, $f4, $rbits, CvtRS)>;
1970+
1971+
def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
1972+
SDT_CVT_F32X4_TO_FPX4_RS_VEC>
1973+
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1974+
(!cast<NVPTXInst>(CVT_ # FPName # "x4_f32x4_rs_sf")
1975+
$f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
1976+
}
19671977

19681978
// RS rounding mode conversions
19691979
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
19701980
// FP8x4 conversions
1971-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E4M3", SDT_CVT_F32X4_TO_FP8X4_RS>
1972-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1973-
(CVT_e4m3x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
1974-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E5M2", SDT_CVT_F32X4_TO_FP8X4_RS>
1975-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1976-
(CVT_e5m2x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
1977-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E4M3", SDT_CVT_F32X4_TO_FP8X4_RS>
1978-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1979-
(CVT_e4m3x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
1980-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E5M2", SDT_CVT_F32X4_TO_FP8X4_RS>
1981-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1982-
(CVT_e5m2x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
1981+
def : CVT_F32X4_TO_FPX4_RS_SF_VEC<"E4M3", v4i8>;
1982+
def : CVT_F32X4_TO_FPX4_RS_SF_VEC<"E5M2", v4i8>;
19831983

19841984
// FP6x4 conversions
1985-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M3", SDT_CVT_F32X4_TO_FP6X4_RS>
1986-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1987-
(CVT_e2m3x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
1988-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E3M2", SDT_CVT_F32X4_TO_FP6X4_RS>
1989-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1990-
(CVT_e3m2x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
1991-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M3", SDT_CVT_F32X4_TO_FP6X4_RS>
1992-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1993-
(CVT_e2m3x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
1994-
def : Pat<(v4i8 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E3M2", SDT_CVT_F32X4_TO_FP6X4_RS>
1995-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1996-
(CVT_e3m2x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
1985+
def : CVT_F32X4_TO_FPX4_RS_SF_VEC<"E2M3", v4i8>;
1986+
def : CVT_F32X4_TO_FPX4_RS_SF_VEC<"E3M2", v4i8>;
19971987

19981988
// FP4x4 conversions
1999-
def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1", SDT_CVT_F32X4_TO_FP4X4_RS>
2000-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
1989+
def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
1990+
SDT_CVT_F32X4_TO_FPX4_RS_INT>
1991+
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
20011992
(CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
2002-
def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1", SDT_CVT_F32X4_TO_FP4X4_RS>
2003-
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
1993+
def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
1994+
SDT_CVT_F32X4_TO_FPX4_RS_INT>
1995+
f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
20041996
(CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
20051997
}
20061998

0 commit comments

Comments
 (0)