Skip to content

Commit 513bf2a

Browse files
committed
[NVPTX] Don't use stack memory when bitcasting to/from 2xi8
`v2i8` is and unsupported type, so we hit the default legalization rules which perform the bitcast in stack memory and is very inefficient on GPU. This adds a custom lowering where we pack `v2i8` into `i16` and from there use another bitcast node to reach the final desired type. And also the inverse unpacking `i16` into `v2i8`.
1 parent c9f01f6 commit 513bf2a

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
551551
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
552552
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
553553
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
554+
555+
// Custom conversions to/from v2i8.
556+
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
557+
554558
// Only logical ops can be done on v4i8 directly, others must be done
555559
// elementwise.
556560
setOperationAction(
@@ -2311,6 +2315,45 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
23112315
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
23122316
}
23132317

2318+
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2319+
// Handle bitcasting to/from v2i8 without hitting the default promotion
2320+
// strategy which goes through stack memory.
2321+
SDNode *Node = Op.getNode();
2322+
SDLoc dl(Node);
2323+
2324+
auto maybeBitcast = [&](EVT vt, SDValue val) {
2325+
if (val->getValueType(0) == vt) {
2326+
return val;
2327+
}
2328+
return DAG.getNode(ISD::BITCAST, dl, vt, val);
2329+
};
2330+
2331+
EVT VT = Op->getValueType(0);
2332+
EVT fromVT = Op->getOperand(0)->getValueType(0);
2333+
2334+
if (VT == MVT::v2i8) {
2335+
SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
2336+
// Promote result to v2i16
2337+
SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
2338+
SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
2339+
SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
2340+
DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
2341+
return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
2342+
} else if (fromVT == MVT::v2i8) {
2343+
SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
2344+
DAG.getIntPtrConstant(0, dl));
2345+
SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
2346+
DAG.getIntPtrConstant(1, dl));
2347+
SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
2348+
SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
2349+
SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
2350+
SDValue reg = DAG.getNode(ISD::OR, dl, MVT::i16,
2351+
{E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
2352+
return maybeBitcast(VT, reg);
2353+
}
2354+
return Op;
2355+
}
2356+
23142357
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
23152358
// would get lowered as two constant loads and vector-packing move.
23162359
// Instead we want just a constant move:
@@ -2818,6 +2861,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28182861
return Op;
28192862
case ISD::BUILD_VECTOR:
28202863
return LowerBUILD_VECTOR(Op, DAG);
2864+
case ISD::BITCAST:
2865+
return LowerBITCAST(Op, DAG);
28212866
case ISD::EXTRACT_SUBVECTOR:
28222867
return Op;
28232868
case ISD::EXTRACT_VECTOR_ELT:
@@ -6413,6 +6458,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
64136458
switch (N->getOpcode()) {
64146459
default:
64156460
report_fatal_error("Unhandled custom legalization");
6461+
case ISD::BITCAST:
6462+
Results.push_back(LowerBITCAST(SDValue(N, 0), DAG));
6463+
return;
64166464
case ISD::LOAD:
64176465
ReplaceLoadVector(N, DAG, Results);
64186466
return;

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
616616
const NVPTXSubtarget &STI; // cache the subtarget here
617617
SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
618618

619+
SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
620+
619621
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
620622
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
621623
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
2+
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
3+
; RUN: | FileCheck -allow-deprecated-dag-overlap -check-prefixes COMMON,I16x2 %s
4+
; RUN: %if ptxas %{ \
5+
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
6+
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
7+
; RUN: | %ptxas-verify -arch=sm_90 \
8+
; RUN: %}
9+
10+
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
11+
12+
; COMMON-LABEL: test_trunc_2xi8(
13+
; COMMON: ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
14+
; COMMON: mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
15+
; COMMON: shl.b16 [[RS3:%rs[0-9]+]], [[RS2]], 8;
16+
; COMMON: and.b16 [[RS4:%rs[0-9]+]], [[RS1]], 255;
17+
; COMMON: or.b16 [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
18+
; COMMON: cvt.u32.u16 [[R2:%r[0-9]]], [[RS5]]
19+
; COMMON: st.param.b32 [func_retval0+0], [[R2]];
20+
define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
21+
%trunc = trunc <2 x i16> %a to <2 x i8>
22+
%res = bitcast <2 x i8> %trunc to i16
23+
ret i16 %res
24+
}
25+
26+
; COMMON-LABEL: test_zext_2xi8(
27+
; COMMON: ld.param.u16 [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
28+
; COMMON: shr.u16 [[RS2:%rs[0-9]+]], [[RS1]], 8;
29+
; COMMON: mov.b32 [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
30+
; COMMON: and.b32 [[R2:%r[0-9]+]], [[R1]], 16711935;
31+
; COMMON: st.param.b32 [func_retval0+0], [[R2]];
32+
define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
33+
%vec = bitcast i16 %a to <2 x i8>
34+
%ext = zext <2 x i8> %vec to <2 x i16>
35+
ret <2 x i16> %ext
36+
}

0 commit comments

Comments
 (0)