Commit 92ca087
[NVPTX] fix type propagation when expanding Store[V4 -> V8] (#151576)
This was an edge case we missed. Propagate the correct type when expanding a StoreV4 x <2 x float> to StoreV8 x float.
Parent: 3c08498
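To make the edge case concrete: a StoreV4 whose operands are <2 x float> writes v8f32 (32 bytes) to memory, but when widening to StoreV8 the old combine fell back to the operands' scalar element type (f32) as the memory type. The sketch below illustrates the size mismatch, assuming LLVM's EVT/MVT API from llvm/CodeGen/ValueTypes.h; the Wrong/Right names are illustrative, and this is not the in-tree code.

// Minimal sketch of the type mix-up; illustration only, not the backend code.
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  // A StoreV4 of <2 x float> operands stores v8f32 (32 bytes) in total.
  EVT ElementVT = MVT::v2f32; // operand type before the expansion
  EVT MemVT = MVT::v8f32;     // ST->getMemoryVT(): what is actually stored

  // Old StoreV4 -> StoreV8 path: the memory type fell through as the
  // scalar element type of the operands.
  EVT Wrong = ElementVT.getVectorElementType(); // f32: reports 4 bytes
  // Fixed path: always reuse the MemSDNode's own memory type.
  EVT Right = MemVT;                            // v8f32: reports 32 bytes

  outs() << "wrong: " << Wrong.getStoreSize().getFixedValue() << " bytes, "
         << "right: " << Right.getStoreSize().getFixedValue() << " bytes\n";
  return 0;
}

With the fix, every arm of the opcode ladder and the final getMemIntrinsicNode call reuse ST->getMemoryVT() (or LD->getMemoryVT() on the load side), so the widened node keeps describing the full access.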

2 files changed: +42, -8 lines

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 4 additions & 8 deletions
@@ -4917,7 +4917,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     return SDValue();

   auto *LD = cast<MemSDNode>(N);
-  EVT MemVT = LD->getMemoryVT();
   SDLoc DL(LD);

   // the new opcode after we double the number of operands
@@ -4958,9 +4957,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());

   // Create the new load
-  SDValue NewLoad =
-      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
-                                  Operands, MemVT, LD->getMemOperand());
+  SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
+      Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(),
+      LD->getMemOperand());

   // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
   // the outputs the same. These nodes will be optimized away in later
@@ -5002,7 +5001,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
     return SDValue();

   auto *ST = cast<MemSDNode>(N);
-  EVT MemVT = ElementVT.getVectorElementType();

   // The new opcode after we double the number of operands.
   NVPTXISD::NodeType Opcode;
@@ -5011,11 +5009,9 @@
     // Any packed type is legal, so the legalizer will not have lowered
     // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
     // it here.
-    MemVT = ST->getMemoryVT();
     Opcode = NVPTXISD::StoreV2;
     break;
   case NVPTXISD::StoreV2:
-    MemVT = ST->getMemoryVT();
     Opcode = NVPTXISD::StoreV4;
     break;
   case NVPTXISD::StoreV4:
@@ -5066,7 +5062,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N,

   // Now we replace the store
   return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
-                                     MemVT, ST->getMemOperand());
+                                     ST->getMemoryVT(), ST->getMemOperand());
 }

 static SDValue PerformStoreCombine(SDNode *N,

llvm/test/CodeGen/NVPTX/fold-movs.ll

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx88 -O3 -disable-post-ra \
+; RUN:   -frame-pointer=all -verify-machineinstrs \
+; RUN: | FileCheck %s --check-prefixes=CHECK-F32X2
+; RUN: %if ptxas-12.7 %{ \
+; RUN:   llc < %s -mcpu=sm_100 -mattr=+ptx88 -O3 -disable-post-ra \
+; RUN:   -frame-pointer=all -verify-machineinstrs | %ptxas-verify -arch=sm_100 \
+; RUN: %}
+target triple = "nvptx64-nvidia-cuda"
+
+; Since fdiv doesn't support f32x2, this will create BUILD_VECTORs that will be
+; folded into the store, turning it into st.global.v8.b32.
+define void @writevec(<8 x float> %v1, <8 x float> %v2, ptr addrspace(1) %p) {
+; CHECK-F32X2-LABEL: writevec(
+; CHECK-F32X2:       {
+; CHECK-F32X2-NEXT:    .reg .b32 %r<25>;
+; CHECK-F32X2-NEXT:    .reg .b64 %rd<2>;
+; CHECK-F32X2-EMPTY:
+; CHECK-F32X2-NEXT:  // %bb.0:
+; CHECK-F32X2-NEXT:    ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [writevec_param_0];
+; CHECK-F32X2-NEXT:    ld.param.v4.b32 {%r5, %r6, %r7, %r8}, [writevec_param_0+16];
+; CHECK-F32X2-NEXT:    ld.param.v4.b32 {%r9, %r10, %r11, %r12}, [writevec_param_1+16];
+; CHECK-F32X2-NEXT:    div.rn.f32 %r13, %r8, %r12;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r14, %r7, %r11;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r15, %r6, %r10;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r16, %r5, %r9;
+; CHECK-F32X2-NEXT:    ld.param.v4.b32 {%r17, %r18, %r19, %r20}, [writevec_param_1];
+; CHECK-F32X2-NEXT:    div.rn.f32 %r21, %r4, %r20;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r22, %r3, %r19;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r23, %r2, %r18;
+; CHECK-F32X2-NEXT:    div.rn.f32 %r24, %r1, %r17;
+; CHECK-F32X2-NEXT:    ld.param.b64 %rd1, [writevec_param_2];
+; CHECK-F32X2-NEXT:    st.global.v8.b32 [%rd1], {%r24, %r23, %r22, %r21, %r16, %r15, %r14, %r13};
+; CHECK-F32X2-NEXT:    ret;
+  %v = fdiv <8 x float> %v1, %v2
+  store <8 x float> %v, ptr addrspace(1) %p, align 32
+  ret void
+}
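The comment in the test explains the setup: fdiv has no f32x2 form, so the eight-wide division happens in scalar f32 registers, and the packing movs that would rebuild <2 x float> pairs are instead folded into a single st.global.v8.b32. Below is a rough standalone model of that operand expansion in plain C++ with hypothetical names and no LLVM dependency, purely for illustration.

// Models StoreV4 x <2 x float> -> StoreV8 x float: the operands double and
// scalarize, while the bytes written to memory stay the same.
#include <cstdio>
#include <vector>

struct PackedPair { float lo, hi; }; // stands in for one <2 x float> operand

int main() {
  // StoreV4: four packed operands, 4 * 8 = 32 bytes of payload.
  std::vector<PackedPair> storeV4 = {{0, 1}, {2, 3}, {4, 5}, {6, 7}};

  // StoreV8: each pair is split into two scalar operands, 8 * 4 = 32 bytes.
  std::vector<float> storeV8;
  for (const PackedPair &p : storeV4) {
    storeV8.push_back(p.lo);
    storeV8.push_back(p.hi);
  }

  // The operand type changed (<2 x float> -> float) but the memory footprint
  // did not; that is why the memory VT must come from the store node itself.
  std::printf("operands %zu -> %zu, payload %zu -> %zu bytes\n",
              storeV4.size(), storeV8.size(),
              storeV4.size() * sizeof(PackedPair),
              storeV8.size() * sizeof(float));
  return 0;
}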
