@@ -829,7 +829,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
-                       ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
+                       ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
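Context for the one-line change above: setTargetDAGCombine() only records interest in these opcodes; the generic DAGCombiner later checks that registration and re-dispatches matching nodes to the target. A simplified sketch of the hand-off, assuming the usual TargetLowering::hasTargetDAGCombine()/PerformDAGCombine() pair (illustrative fragment, not the literal DAGCombiner code):

    // Roughly what the generic combiner does for each node N it visits:
    if (TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
      // With ISD::STORE now registered, every generic store node is offered
      // to NVPTXTargetLowering::PerformDAGCombine (see the last two hunks).
      if (SDValue RV = TLI.PerformDAGCombine(N, DCI))
        return RV;
    }
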
@@ -3143,10 +3143,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
 
-  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
-  // unaligned loads and have to handle it here.
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned loads and have to handle it here.
   EVT VT = Op.getValueType();
-  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
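The tail of this guarded block is unchanged and therefore not shown in the hunk. For orientation, the expansion it performs presumably follows the standard pattern built on TargetLowering::expandUnalignedLoad(), sketched here under that assumption:

    // Sketch: if the (now legal) v2f32 load is under-aligned for its address
    // space, expand it into smaller aligned loads plus the code that
    // reassembles the vector value.
    if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                        MemVT, *Load->getMemOperand())) {
      SDValue Ops[2];
      std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG); // {value, chain}
      return DAG.getMergeValues(Ops, SDLoc(Op));
    }
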
@@ -3190,22 +3190,22 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);
 
-  // v2f16 is legal, so we can't rely on legalizer to handle unaligned
-  // stores and have to handle it here.
-  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
 
-  // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT) || VT == MVT::v4i8)
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
     return SDValue();
 
   return LowerSTOREVector(Op, DAG);
 }
 
-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const SmallVectorImpl<SDValue> &Elements) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
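expandUnalignedStore() is the generic TargetLowering helper; with v2f32 added to the check, an under-aligned v2f32 store now takes the same path as the 16-bit pairs. Conceptually, for a fully misaligned store the helper bitcasts the value to an integer and emits byte-sized pieces, roughly like this hand-written sketch (illustrative only; the real helper handles more cases and picks wider pieces when the alignment allows):

    // Sketch: what an align(1) v2f32 store expands to, conceptually.
    SDValue WideInt = DAG.getBitcast(MVT::i64, Store->getValue());
    SmallVector<SDValue, 8> Chains;
    for (unsigned i = 0; i < 8; ++i) {
      SDValue Byte = DAG.getNode(ISD::SRL, DL, MVT::i64, WideInt,
                                 DAG.getConstant(i * 8, DL, MVT::i64));
      Byte = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Byte);
      SDValue Ptr = DAG.getMemBasePlusOffset(Store->getBasePtr(),
                                             TypeSize::getFixed(i), DL);
      Chains.push_back(
          DAG.getStore(Store->getChain(), DL, Byte, Ptr,
                       Store->getPointerInfo().getWithOffset(i), Align(1)));
    }
    return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
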
@@ -3266,6 +3266,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
                                 NumEltsPerSubVector);
       Ops.push_back(DAG.getBuildVector(EltVT, DL, SubVectorElts));
     }
+  } else if (!Elements.empty()) {
+    Ops.insert(Ops.end(), Elements.begin(), Elements.end());
   } else {
     SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
     for (const unsigned I : llvm::seq(NumElts)) {
@@ -3289,10 +3291,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                               N->getMemoryVT(), N->getMemOperand());
 
-  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }
 
+// Default variant where we don't pass in elements.
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
+  return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
+}
+
+SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  return convertVectorStore(Op, DAG);
+}
+
 // st i1 v, addr
 // =>
 // v1 = zxt v to i16
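The refactoring above splits the old LowerSTOREVector() into a static worker plus two entry points. The new Elements parameter lets a caller that already holds the scalar values (for example, the operands of a BUILD_VECTOR) feed them straight into the store's operand list, skipping the getBitcast()/EXTRACT_VECTOR_ELT path in the worker. A hypothetical call site, sketched:

    // StoreNode is an ISD::STORE of a BUILD_VECTOR whose operands A and B we
    // already have in hand; pass them through so no extract nodes are created
    // just to be folded away later.
    SmallVector<SDValue> Elts = {A, B};
    SDValue NewStore = convertVectorStore(SDValue(StoreNode, 0), DAG, Elts);

This is exactly how the new PerformStoreCombine() further down uses it.
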
@@ -5413,6 +5424,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
   //   -->
   // StoreRetvalV2 {a, b}
   // likewise for V2 -> V4 case
+  //
+  // We also handle target-independent stores, which require us to first
+  // convert to StoreV2.
 
   std::optional<NVPTXISD::NodeType> NewOpcode;
   switch (N->getOpcode()) {
@@ -5438,8 +5452,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
     SDValue CurrentOp = N->getOperand(I);
     if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
       assert(CurrentOp.getValueType() == MVT::v2f32);
-      NewOps.push_back(CurrentOp.getNode()->getOperand(0));
-      NewOps.push_back(CurrentOp.getNode()->getOperand(1));
+      NewOps.push_back(CurrentOp.getOperand(0));
+      NewOps.push_back(CurrentOp.getOperand(1));
     } else {
       NewOps.clear();
       break;
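For a concrete picture of the combine this loop implements (unchanged here apart from the simpler operand accessors), the rewrite on a v2f32 return value looks roughly like this (operand order sketched, not exact):

    // Before: StoreRetval   chain, offset, (v2f32 BUILD_VECTOR a, b)
    // After:  StoreRetvalV2 chain, offset, a, b
    // and likewise a StoreRetvalV2 of two v2f32s becomes a StoreRetvalV4 of
    // four f32 scalars. Each BUILD_VECTOR contributes its two operands:
    NewOps.push_back(CurrentOp.getOperand(0)); // first f32 element (a)
    NewOps.push_back(CurrentOp.getOperand(1)); // second f32 element (b)
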
@@ -6210,6 +6224,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  // Check if the stored value can be scalarized.
+  SDValue StoredVal = N->getOperand(1);
+  if (StoredVal.getValueType() == MVT::v2f32 &&
+      StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
+    SmallVector<SDValue> Elements(StoredVal->op_values());
+    return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
+  }
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
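Putting the pieces together: the ISD::STORE registration in the first hunk routes generic stores here, and a v2f32 store of a BUILD_VECTOR is re-emitted through convertVectorStore() as an NVPTXISD::StoreV2 carrying the two f32 scalars. The presumable payoff is that the elements are stored directly instead of first being packed into a single 64-bit register. Sketched at the DAG level (node shapes illustrative):

    // Before: ch = store<(store (s64))> ch, (v2f32 BUILD_VECTOR a, b), ptr
    // After:  ch = NVPTXISD::StoreV2 ch, a, b, ptr
    //         (built via DAG.getMemIntrinsicNode() in convertVectorStore(),
    //         typically selected to a st.v2.f32 instruction)
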
@@ -6239,6 +6265,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case NVPTXISD::LoadParam:
   case NVPTXISD::LoadParamV2:
     return PerformLoadCombine(N, DCI);
+  case ISD::STORE:
+    return PerformStoreCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4: