Skip to content

Commit 3e34b3e

Browse files
committed
[NVPTX] Add DAG combine patterns to simplify IMAD
I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/45W3Wcnxz. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. This PR adds:
```
mad x 1 y => add x y
mad x -1 y => sub y x
mad x 0 y => y
mad x y 0 => mul x y
mad c0 c1 z => add z (c0 * c1)
```
Another option might be to remove `NVPTXISD::IMAD` and only combine to mad during selection. This would allow the normal DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention. However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.
1 parent 9abcca5 commit 3e34b3e

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
51645164
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
51655165
}
51665166

5167+
/// Try to simplify the NVPTXISD::IMAD node \p N, treating its operands as
/// (N0, N1, N2), i.e. the value N0 * N1 + N2.
///
/// Only N1 and N2 are tested for the constant special cases, so callers that
/// want the multiplication treated as commutative must call this a second
/// time with N0 and N1 swapped.
///
/// \returns the simplified value, or an empty SDValue if no pattern matched.
static SDValue
PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
                               TargetLowering::DAGCombinerInfo &DCI) {
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
  // Build replacements with the type of the node being replaced.
  // N0->getValueType(0) would query result #0 of N0's defining node and
  // ignore which result N0 actually refers to.
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  // mad x 1 y => add x y
  if (N1C && N1C->isOne())
    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);

  // mad x -1 y => sub y x
  if (N1C && N1C->isAllOnes()) {
    // Drop the no-unsigned-wrap flag before attaching the flags to the SUB.
    Flags.setNoUnsignedWrap(false);
    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
  }

  // mad x 0 y => y
  if (N1C && N1C->isZero())
    return N2;

  // mad x y 0 => mul x y
  if (N2C && N2C->isZero())
    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);

  // mad c0 c1 x => add x (c0*c1)
  // FoldConstantArithmetic yields an empty SDValue when it cannot fold, in
  // which case we fall through and report "no change".
  if (SDValue C =
          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);

  return {};
}
5201+
5202+
/// DAG-combine hook for NVPTXISD::IMAD nodes.
///
/// Attempts the IMAD simplifications with the multiplicands in their original
/// order first and, if nothing matched, once more with them exchanged.
///
/// \returns the replacement value, or an empty SDValue if neither order
/// matched.
static SDValue PerformIMADCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Mul0 = N->getOperand(0);
  SDValue Mul1 = N->getOperand(1);
  SDValue Addend = N->getOperand(2);

  // Original operand order.
  if (SDValue Res = PerformIMADCombineWithOperands(N, Mul0, Mul1, Addend, DCI))
    return Res;

  // Retry with the two multiplicands swapped.
  return PerformIMADCombineWithOperands(N, Mul1, Mul0, Addend, DCI);
}
5213+
51675214
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
51685215
DAGCombinerInfo &DCI) const {
51695216
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5198,6 +5245,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
51985245
return PerformVSELECTCombine(N, DCI);
51995246
case ISD::BUILD_VECTOR:
52005247
return PerformBUILD_VECTORCombine(N, DCI);
5248+
case NVPTXISD::IMAD:
5249+
return PerformIMADCombine(N, DCI);
52015250
}
52025251
return SDValue();
52035252
}

llvm/test/CodeGen/NVPTX/combine-mad.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,23 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
183183
%add = add i32 %c, %sel
184184
ret i32 %add
185185
}
186+
187+
;; This case relies on mad x 1 y => add x y; previously we emitted:
188+
;; mad.lo.s32 %r3, %r1, 1, %r2;
189+
define i32 @test_mad_fold(i32 %x) {
190+
; CHECK-LABEL: test_mad_fold(
191+
; CHECK: {
192+
; CHECK-NEXT: .reg .b32 %r<7>;
193+
; CHECK-EMPTY:
194+
; CHECK-NEXT: // %bb.0:
195+
; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_param_0];
196+
; CHECK-NEXT: mul.hi.s32 %r2, %r1, -2147221471;
197+
; CHECK-NEXT: add.s32 %r3, %r1, %r2;
198+
; CHECK-NEXT: shr.u32 %r4, %r3, 31;
199+
; CHECK-NEXT: shr.s32 %r5, %r3, 12;
200+
; CHECK-NEXT: add.s32 %r6, %r5, %r4;
201+
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
202+
; CHECK-NEXT: ret;
;; Signed division by the constant 8191. The CHECK lines above expect a
;; mul.hi / add / shift sequence that uses a plain add.s32 and no mad.lo.
203+
%div = sdiv i32 %x, 8191
204+
ret i32 %div
205+
}

llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
; CHECK-NOT: __local_depot
1313

1414
; CHECK-32: ld.param.u32 %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
15-
; CHECK-32-NEXT: mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
15+
; CHECK-32-NEXT: add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
1616
; CHECK-32-NEXT: and.b32 %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
1717
; CHECK-32-NEXT: alloca.u32 %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
1818
; CHECK-32-NEXT: cvta.local.u32 %r[[ALLOCA]], %r[[ALLOCA]];

0 commit comments

Comments
 (0)