Skip to content

Commit 403c693

Browse files
committed
[NVPTX] Fixup EXT_LOAD lowering for i128 values
1 parent 71039bb commit 403c693

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/Support/CommandLine.h"
2727
#include "llvm/Support/ErrorHandling.h"
2828
#include "llvm/Support/FormatVariadic.h"
29+
#include "llvm/Support/MathExtras.h"
2930
#include <optional>
3031

3132
using namespace llvm;
@@ -1141,6 +1142,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11411142
else
11421143
FromType = getLdStRegType(ScalarVT);
11431144

1145+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1146+
FromTypeWidth <= 128 && "Invalid width for load");
1147+
11441148
// Create the machine instruction DAG
11451149
SDValue Offset, Base;
11461150
SelectADDR(N->getOperand(1), Base, Offset);
@@ -1236,6 +1240,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12361240
FromType = NVPTX::PTXLdStInstCode::Untyped;
12371241
}
12381242

1243+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1244+
FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
1245+
12391246
SDValue Offset, Base;
12401247
SelectADDR(N->getOperand(1), Base, Offset);
12411248
SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1453,6 +1460,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14531460
// Create the machine instruction DAG
14541461
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
14551462

1463+
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 &&
1464+
ToTypeWidth <= 128 && "Invalid width for store");
1465+
14561466
SDValue Offset, Base;
14571467
SelectADDR(ST->getBasePtr(), Base, Offset);
14581468

@@ -1537,6 +1547,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15371547
ToType = NVPTX::PTXLdStInstCode::Untyped;
15381548
}
15391549

1550+
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 &&
1551+
ToTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for store");
1552+
15401553
SDValue Offset, Base;
15411554
SelectADDR(N2, Base, Offset);
15421555

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3191,20 +3191,22 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31913191

31923192
SDValue
31933193
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3194-
SDNode *N = Op.getNode();
3194+
MemSDNode *N = cast<MemSDNode>(Op.getNode());
31953195
SDValue Val = N->getOperand(1);
31963196
SDLoc DL(N);
3197-
EVT ValVT = Val.getValueType();
3197+
const EVT ValVT = Val.getValueType();
3198+
const EVT MemVT = N->getMemoryVT();
3199+
if (ValVT != MemVT)
3200+
return SDValue();
31983201

31993202
const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
32003203
if (!NumEltsAndEltVT)
32013204
return SDValue();
32023205
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
32033206

3204-
MemSDNode *MemSD = cast<MemSDNode>(N);
32053207
const DataLayout &TD = DAG.getDataLayout();
32063208

3207-
Align Alignment = MemSD->getAlign();
3209+
Align Alignment = N->getAlign();
32083210
Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
32093211
if (Alignment < PrefAlign) {
32103212
// This store is not sufficiently aligned, so bail out and let this vector
@@ -3267,7 +3269,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32673269

32683270
SDValue NewSt =
32693271
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3270-
MemSD->getMemoryVT(), MemSD->getMemOperand());
3272+
N->getMemoryVT(), N->getMemOperand());
32713273

32723274
// return DCI.CombineTo(N, NewSt, true);
32733275
return NewSt;
@@ -5762,20 +5764,20 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
57625764
/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
57635765
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
57645766
SmallVectorImpl<SDValue> &Results) {
5765-
const EVT ResVT = N->getValueType(0);
5766-
SDLoc DL(N);
5767+
LoadSDNode *LD = cast<LoadSDNode>(N);
5768+
const EVT ResVT = LD->getValueType(0);
5769+
const EVT MemVT = LD->getMemoryVT();
5770+
if (ResVT != MemVT)
5771+
return;
57675772

57685773
const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
57695774
if (!NumEltsAndEltVT)
57705775
return;
57715776
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
57725777

5773-
LoadSDNode *LD = cast<LoadSDNode>(N);
5774-
57755778
Align Alignment = LD->getAlign();
57765779
const auto &TD = DAG.getDataLayout();
5777-
Align PrefAlign =
5778-
TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5780+
Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
57795781
if (Alignment < PrefAlign) {
57805782
// This load is not sufficiently aligned, so bail out and let this vector
57815783
// load be scalarized. Note that we may still be able to emit smaller
@@ -5806,9 +5808,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
58065808
break;
58075809
}
58085810
}
5811+
SDLoc DL(LD);
58095812

58105813
// Copy regular operands
5811-
SmallVector<SDValue, 8> OtherOps(N->ops());
5814+
SmallVector<SDValue, 8> OtherOps(LD->ops());
58125815

58135816
// The select routine does not have access to the LoadSDNode instance, so
58145817
// pass along the extension information
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -O0 -mcpu=sm_20 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -O0 -mcpu=sm_20 | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
define i128 @foo() {
8+
; CHECK-LABEL: foo(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b64 %rd<3>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0: // %entry
13+
; CHECK-NEXT: bra.uni $L__BB0_1;
14+
; CHECK-NEXT: $L__BB0_1: // %while.cond
15+
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
16+
; CHECK-NEXT: mov.b64 %rd1, 0;
17+
; CHECK-NEXT: ld.u8 %rd2, [%rd1];
18+
; CHECK-NEXT: st.v2.u64 [%rd1], {%rd2, %rd1};
19+
; CHECK-NEXT: bra.uni $L__BB0_1;
20+
entry:
21+
br label %while.cond
22+
23+
while.cond: ; preds = %while.cond, %entry
24+
%0 = load i8, ptr null, align 1
25+
%conv = zext i8 %0 to i128
26+
store i128 %conv, ptr null, align 16
27+
br label %while.cond
28+
}

0 commit comments

Comments
 (0)