[NVPTX] legalize v2i32 to improve compatibility with v2f32 #153478
Conversation
@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

Transform `(v2f32 (bitcast (v2i32 build_vector a, b)))` into `(v2f32 (build_vector (f32 (bitcast a)), (f32 (bitcast b))))`.

Since v2f32 is legal but v2i32 is not, a v2i32 build_vector would otherwise be legalized as bitwise ops on i64, when we want each 32-bit element to be in its own register.

Fixes #153109

Full diff: https://github.com/llvm/llvm-project/pull/153478.diff

2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3daf25d551520..fcabe49e09c6c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -892,10 +892,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
- ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::BITCAST,
+ ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL,
+ ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR,
+ ISD::ADDRSPACECAST, ISD::LOAD, ISD::STORE,
+ ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5334,6 +5335,26 @@ static SDValue PerformANDCombine(SDNode *N,
return SDValue();
}
+static SDValue combineBitcast(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ const SDValue &Input = N->getOperand(0);
+ const MVT FromVT = Input.getSimpleValueType();
+ const MVT ToVT = N->getSimpleValueType(0);
+ const SDLoc DL(N);
+
+ if (Input.getOpcode() != ISD::BUILD_VECTOR || ToVT != MVT::v2f32 ||
+ !(FromVT.isVector() &&
+ FromVT.getVectorNumElements() == ToVT.getVectorNumElements()))
+ return SDValue();
+
+ const MVT ToEltVT = ToVT.getVectorElementType();
+
+ // pull in build_vector through bitcast
+ return DCI.DAG.getBuildVector(
+ ToVT, DL,
+ {DCI.DAG.getBitcast(ToEltVT, Input.getOperand(0)),
+ DCI.DAG.getBitcast(ToEltVT, Input.getOperand(1))});
+}
+
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -6007,6 +6028,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
+ case ISD::BITCAST:
+ return combineBitcast(N, DCI);
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return combineMulWide(N, DCI, OptLevel);
diff --git a/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll b/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll
new file mode 100644
index 0000000000000..4d80ee68faac6
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
+; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.7 %{ \
+; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
+; RUN: -verify-machineinstrs | %ptxas-verify -arch=sm_90a \
+; RUN: %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+target triple = "nvptx64-nvidia-cuda"
+
+define ptx_kernel void @kernel1() {
+; CHECK-LABEL: kernel1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %bb
+; CHECK-NEXT: mov.b32 %r3, 0;
+; CHECK-NEXT: mov.b32 %r4, %r3;
+; CHECK-NEXT: mov.b32 %r1, %r3;
+; CHECK-NEXT: mov.b32 %r2, %r4;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: { .reg .pred p; setp.ne.b32 p, 66, 0; wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r1,%r2}, 64, 65, p, 67, 68, 69, 70; }
+; CHECK-EMPTY:
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: mov.b32 %r5, %r3;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: mbarrier.arrive.release.cta.shared::cta.b64 %rd1, [%r5], 1; // XXSTART
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: wgmma.wait_group.sync.aligned 0;
+; CHECK-NEXT: mov.b32 %r6, %r3;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: mbarrier.arrive.release.cta.shared::cta.b64 %rd2, [%r6], 1; // XXEND
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: mul.rn.f32 %r7, %r1, 0f00000000;
+; CHECK-NEXT: mul.rn.f32 %r8, %r2, 0f00000000;
+; CHECK-NEXT: add.rn.f32 %r9, %r8, %r7;
+; CHECK-NEXT: shfl.sync.bfly.b32 %r10, %r9, 0, 0, 0;
+; CHECK-NEXT: ret;
+bb:
+ %i = call { <1 x float>, <1 x float> } asm sideeffect "{ .reg .pred p; setp.ne.b32 p, 66, 0; wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1}, 64, 65, p, 67, 68, 69, 70; }\0A", "=f,=f,0,1"(<1 x float> zeroinitializer, <1 x float> zeroinitializer)
+ %i1 = extractvalue { <1 x float>, <1 x float> } %i, 0
+ %i2 = extractvalue { <1 x float>, <1 x float> } %i, 1
+ %i3 = call i64 asm sideeffect " mbarrier.arrive.release.cta.shared::cta.b64 $0, [$1], 1; // XXSTART ", "=l,r"(ptr addrspace(3) null)
+ call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+ %i4 = shufflevector <1 x float> %i1, <1 x float> %i2, <2 x i32> <i32 0, i32 1>
+ %i5 = call i64 asm sideeffect " mbarrier.arrive.release.cta.shared::cta.b64 $0, [$1], 1; // XXEND ", "=l,r"(ptr addrspace(3) null)
+ %i6 = fmul <2 x float> %i4, zeroinitializer
+ %i7 = extractelement <2 x float> %i6, i64 0
+ %i8 = extractelement <2 x float> %i6, i64 1
+ %i9 = fadd float %i8, %i7
+ %i10 = bitcast float %i9 to <1 x i32>
+ %i11 = extractelement <1 x i32> %i10, i64 0
+ %i12 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 0, i32 %i11, i32 0, i32 0)
+ ret void
+}
+
+declare void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 immarg) #0
+
+declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #1
+
+attributes #0 = { convergent nounwind }
+attributes #1 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) }
Updated the code and simplified test cases.
This looks reasonable to me, although I have not had a chance to dig into the issue this is trying to address very deeply. Please wait for @Artem-B to take a look as well.
P.S. I wonder if we should consider just making v2i32 a legal type.
I thought about this as well. I'm hesitant to introduce another legal type and plumb it through NVPTX (although it should be far less effort than v2f32) when there aren't any instructions that support v2i32. If we see more bug reports related to this issue, then I'll reconsider it.
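For context, a minimal sketch of what making v2i32 legal might look like in `NVPTXTargetLowering`'s constructor. This is not code from this PR; the register-class name and the op list are assumptions:

```cpp
// Hypothetical sketch: map v2i32 onto the 64-bit register class, mirroring
// how v2f32 is handled. Int64RegsRegClass is an assumed name here.
addRegisterClass(MVT::v2i32, &NVPTX::Int64RegsRegClass);

// No PTX instructions (besides bitwise logic) operate on v2i32 lanes
// directly, so most arithmetic would still be expanded element-wise.
setOperationAction({ISD::ADD, ISD::SUB, ISD::MUL, ISD::SHL, ISD::SRA,
                    ISD::SRL},
                   MVT::v2i32, Expand);
```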
Looks like CI is red because of a broken Arm/Thumb2 test, but that shouldn't be related to this change.
@Artem-B ping for review.
Hi @Artem-B, what are your thoughts on still getting this checked in? It doesn't fix your bug, but it still improves PTX codegen, even on sm1xx: we don't want to be emitting these logical ops under any circumstances.
It sounds reasonable, but it looks like we're fixing one particular case when we may eventually want a more general fix for construction of small vectors that fit into b32/b64, regardless of their element type. LGTM, but add a TODO that we may need a more general solution.
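A rough, untested sketch of that more general direction, generalizing `combineBitcast` above to any bitcast of a build_vector with a matching element count instead of special-casing v2f32:

```cpp
static SDValue combineBitcastGeneric(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Input = N->getOperand(0);
  const MVT FromVT = Input.getSimpleValueType();
  const MVT ToVT = N->getSimpleValueType(0);
  if (Input.getOpcode() != ISD::BUILD_VECTOR || !FromVT.isVector() ||
      !ToVT.isVector() ||
      FromVT.getVectorNumElements() != ToVT.getVectorNumElements() ||
      ToVT.getFixedSizeInBits() > 64) // only small vectors fitting in b32/b64
    return SDValue();

  const SDLoc DL(N);
  const MVT ToEltVT = ToVT.getVectorElementType();
  SmallVector<SDValue, 4> Elts;
  for (const SDValue &Elt : Input->op_values()) {
    // Bail if the operands were promoted to a type wider than the element.
    if (Elt.getSimpleValueType() != FromVT.getVectorElementType())
      return SDValue();
    Elts.push_back(DCI.DAG.getBitcast(ToEltVT, Elt));
  }
  // bitcast (build_vector a, b, ...) -> build_vector (bitcast a), ...
  return DCI.DAG.getBuildVector(ToVT, DL, Elts);
}
```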
I think we shouldn't need to worry about
✅ With the latest revision this PR passed the C/C++ code formatter.
I'm not sure I can see it. Do we need to add more test cases to show the benefits? All I see in the test changes is that we prefer b64/v2.b64 loads plus splitting moves over direct loads of v2/v4.b32 into 32-bit regs, and in some cases we're not using 256-bit loads where we should (I think). What am I missing?
Sorry, my comment was out of date. The "improvements" were miscompiles on <sm_100 that I fixed.
Still LGTM, but I'm not sure we need the O3 test run.
Since v2f32 is legal but v2i32 is not, this causes some sequences of operations like bitcast (build_vector) to be lowered inefficiently.
This reverts commit 34859a13d0e4b99d70c250d1dad95392291b61a0.
It appears that the patch may have introduced a bug: https://godbolt.org/z/EPPhGEGxT. Looks like we also need to expand `trunc`.
This patch appears to fix the crash. That said, we should be able to actually lower the truncation v2i32->v2i16 into a single PRMT instruction; LLVM currently generates two PRMT instructions that could be combined. We could use custom lowering for the op, as it would help on all GPU variants.
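A hedged sketch of what that single-PRMT lowering could look like. `NVPTXISD::PRMT`'s exact operand list (it may also carry a mode operand) is an assumption here, so treat this as an illustration rather than the real lowering:

```cpp
static SDValue lowerTruncV2I32(SDValue Op, SelectionDAG &DAG) {
  const SDLoc DL(Op);
  SDValue Src = Op.getOperand(0); // v2i32
  SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Src,
                           DAG.getVectorIdxConstant(0, DL));
  SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Src,
                           DAG.getVectorIdxConstant(1, DL));
  // prmt.b32 d, lo, hi, 0x5410 selects bytes {lo.0, lo.1, hi.0, hi.1}: the
  // low 16 bits of each lane, packed into a single b32 register.
  SDValue Packed = DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32, Lo, Hi,
                               DAG.getConstant(0x5410, DL, MVT::i32));
  return DAG.getBitcast(MVT::v2i16, Packed);
}
```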
#153478 made v2i32 legal on newer GPUs, but we cannot lower all operations yet. Expand the `trunc`/`ext` operations until we implement efficient lowering.
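The stop-gap itself would be small; a sketch of what the follow-up describes, with the op list inferred from the commit message rather than copied from the patch:

```cpp
// Sketch of the stop-gap: expand the conversions we cannot yet lower well.
// NOTE: the op list is inferred from the commit message, and which VT each
// action must be keyed on is an assumption.
setOperationAction({ISD::SIGN_EXTEND, ISD::ZERO_EXTEND, ISD::ANY_EXTEND,
                    ISD::TRUNCATE},
                   MVT::v2i32, Expand);
```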
… (llvm#162391) Follow-up on llvm#153478 and llvm#161715. The v2i32 register class exists mostly to facilitate v2f32's use of integer registers. There are no actual instructions that apply to v2i32 directly (except bitwise logical ops); everything else must be done element-wise.