diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index a95cba586b8fc..ba21733e96165 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -409,6 +409,13 @@ VectorizePTXValueVTs(const SmallVectorImpl &ValueVTs, return VectorInfo; } +static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, + SDValue Value) { + if (Value->getValueType(0) == VT) + return Value; + return DAG.getNode(ISD::BITCAST, DL, VT, Value); +} + // NVPTXTargetLowering Constructor. NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI) @@ -551,6 +558,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom); setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom); + + // Custom conversions to/from v2i8. + setOperationAction(ISD::BITCAST, MVT::v2i8, Custom); + // Only logical ops can be done on v4i8 directly, others must be done // elementwise. setOperationAction( @@ -2311,6 +2322,30 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { return DAG.getBuildVector(Node->getValueType(0), dl, Ops); } +SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { + // Handle bitcasting from v2i8 without hitting the default promotion + // strategy which goes through stack memory. + EVT FromVT = Op->getOperand(0)->getValueType(0); + if (FromVT != MVT::v2i8) { + return Op; + } + + // Pack vector elements into i16 and bitcast to final type + SDLoc DL(Op); + SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, + Op->getOperand(0), DAG.getIntPtrConstant(0, DL)); + SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, + Op->getOperand(0), DAG.getIntPtrConstant(1, DL)); + SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0); + SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1); + SDValue Const8 = DAG.getConstant(8, DL, MVT::i16); + SDValue AsInt = DAG.getNode( + ISD::OR, DL, MVT::i16, + {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})}); + EVT ToVT = Op->getValueType(0); + return MaybeBitcast(DAG, DL, ToVT, AsInt); +} + // We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it // would get lowered as two constant loads and vector-packing move. // Instead we want just a constant move: @@ -2818,6 +2853,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return Op; case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG); + case ISD::BITCAST: + return LowerBITCAST(Op, DAG); case ISD::EXTRACT_SUBVECTOR: return Op; case ISD::EXTRACT_VECTOR_ELT: @@ -6128,6 +6165,28 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return SDValue(); } +static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, + SmallVectorImpl &Results) { + // Handle bitcasting to v2i8 without hitting the default promotion + // strategy which goes through stack memory. + SDValue Op(Node, 0); + EVT ToVT = Op->getValueType(0); + if (ToVT != MVT::v2i8) { + return; + } + + // Bitcast to i16 and unpack elements into a vector + SDLoc DL(Node); + SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0)); + SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt); + SDValue Const8 = DAG.getConstant(8, DL, MVT::i16); + SDValue Vec1 = + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, + DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8})); + Results.push_back( + DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1})); +} + /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads. static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, SmallVectorImpl &Results) { @@ -6413,6 +6472,9 @@ void NVPTXTargetLowering::ReplaceNodeResults( switch (N->getOpcode()) { default: report_fatal_error("Unhandled custom legalization"); + case ISD::BITCAST: + ReplaceBITCAST(N, DAG, Results); + return; case ISD::LOAD: ReplaceLoadVector(N, DAG, Results); return; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 824a659671967..13153f4830b69 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering { const NVPTXSubtarget &STI; // cache the subtarget here SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const; + SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll new file mode 100644 index 0000000000000..df9c3e59b0e6b --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll @@ -0,0 +1,33 @@ +; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \ +; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \ +; RUN: | FileCheck %s +; RUN: %if ptxas %{ \ +; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \ +; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \ +; RUN: | %ptxas-verify -arch=sm_90 \ +; RUN: %} + +target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" + +; CHECK-LABEL: test_bitcast_2xi8_i16( +; CHECK: ld.param.u32 %r1, [test_bitcast_2xi8_i16_param_0]; +; CHECK: mov.b32 {%rs1, %rs2}, %r1; +; CHECK: shl.b16 %rs3, %rs2, 8; +; CHECK: and.b16 %rs4, %rs1, 255; +; CHECK: or.b16 %rs5, %rs4, %rs3; +; CHECK: cvt.u32.u16 %r2, %rs5; +; CHECK: st.param.b32 [func_retval0], %r2; +define i16 @test_bitcast_2xi8_i16(<2 x i8> %a) { + %res = bitcast <2 x i8> %a to i16 + ret i16 %res +} + +; CHECK-LABEL: test_bitcast_i16_2xi8( +; CHECK: ld.param.u16 %rs1, [test_bitcast_i16_2xi8_param_0]; +; CHECK: shr.u16 %rs2, %rs1, 8; +; CHECK: mov.b32 %r1, {%rs1, %rs2}; +; CHECK: st.param.b32 [func_retval0], %r1; +define <2 x i8> @test_bitcast_i16_2xi8(i16 %a) { + %res = bitcast i16 %a to <2 x i8> + ret <2 x i8> %res +}