Skip to content

Commit 970983f

Browse files
committed
[RISCV] Extend zvqdot matching to handle reduction trees
Now that we have matching for vqdot in it's basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important as we canonicalize reductions by performing a tree in the vector domain before the root reduction instruction. The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in DAG. While conceptually problematic for compile time, this is probably fine in practice as we should only visit the root once per pass of DAGCombine. I don't really see a better solution - suggestions welcome.
1 parent e0a951f commit 970983f

File tree

2 files changed

+145
-47
lines changed

2 files changed

+145
-47
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18131,6 +18131,29 @@ static MVT getQDOTXResultType(MVT OpVT) {
1813118131
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
1813218132
}
1813318133

18134+
/// Given fixed length vectors A and B with equal element types, but possibly
18135+
/// different number of elements, return A + B where either A or B is zero
18136+
/// padded to the larger number of elements.
18137+
static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
18138+
SelectionDAG &DAG) {
18139+
// NOTE: Manually doing the extract/add/insert scheme produces
18140+
// significantly better coegen than the naive pad with zeros
18141+
// and add scheme.
18142+
EVT AVT = A.getValueType();
18143+
EVT BVT = B.getValueType();
18144+
assert(AVT.getVectorElementType() == BVT.getVectorElementType());
18145+
if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
18146+
std::swap(A, B);
18147+
std::swap(AVT, BVT);
18148+
}
18149+
18150+
SDValue BPart = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AVT, B,
18151+
DAG.getVectorIdxConstant(0, DL));
18152+
SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
18153+
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, BVT, B, Res,
18154+
DAG.getVectorIdxConstant(0, DL));
18155+
}
18156+
1813418157
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1813518158
SelectionDAG &DAG,
1813618159
const RISCVSubtarget &Subtarget,
@@ -18142,6 +18165,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1814218165
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
1814318166
return SDValue();
1814418167

18168+
// Recurse through adds (since generic dag canonicalizes to that
18169+
// form).
18170+
if (InVec->getOpcode() == ISD::ADD) {
18171+
SDValue A = InVec.getOperand(0);
18172+
SDValue B = InVec.getOperand(1);
18173+
SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
18174+
SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
18175+
if (AOpt || BOpt) {
18176+
if (AOpt)
18177+
A = AOpt;
18178+
if (BOpt)
18179+
B = BOpt;
18180+
// From here, we're doing A + B with mixed types, implicitly zero
18181+
// padded to the wider type. Note that we *don't* need the result
18182+
// type to be the original VT, and in fact prefer narrower ones
18183+
// if possible.
18184+
return getZeroPaddedAdd(DL, A, B, DAG);
18185+
}
18186+
}
18187+
1814518188
// reduce (zext a) <--> reduce (mul zext a. zext 1)
1814618189
// reduce (sext a) <--> reduce (mul sext a. sext 1)
1814718190
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 102 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,31 @@ entry:
299299
}
300300

301301
define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
302-
; CHECK-LABEL: vqdot_vv_accum:
303-
; CHECK: # %bb.0: # %entry
304-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
305-
; CHECK-NEXT: vsext.vf2 v10, v8
306-
; CHECK-NEXT: vsext.vf2 v16, v9
307-
; CHECK-NEXT: vwmacc.vv v12, v10, v16
308-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
309-
; CHECK-NEXT: vmv.s.x v8, zero
310-
; CHECK-NEXT: vredsum.vs v8, v12, v8
311-
; CHECK-NEXT: vmv.x.s a0, v8
312-
; CHECK-NEXT: ret
302+
; NODOT-LABEL: vqdot_vv_accum:
303+
; NODOT: # %bb.0: # %entry
304+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
305+
; NODOT-NEXT: vsext.vf2 v10, v8
306+
; NODOT-NEXT: vsext.vf2 v16, v9
307+
; NODOT-NEXT: vwmacc.vv v12, v10, v16
308+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
309+
; NODOT-NEXT: vmv.s.x v8, zero
310+
; NODOT-NEXT: vredsum.vs v8, v12, v8
311+
; NODOT-NEXT: vmv.x.s a0, v8
312+
; NODOT-NEXT: ret
313+
;
314+
; DOT-LABEL: vqdot_vv_accum:
315+
; DOT: # %bb.0: # %entry
316+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
317+
; DOT-NEXT: vmv.v.i v10, 0
318+
; DOT-NEXT: vqdot.vv v10, v8, v9
319+
; DOT-NEXT: vadd.vv v8, v10, v12
320+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
321+
; DOT-NEXT: vmv.v.v v12, v8
322+
; DOT-NEXT: vmv.s.x v8, zero
323+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
324+
; DOT-NEXT: vredsum.vs v8, v12, v8
325+
; DOT-NEXT: vmv.x.s a0, v8
326+
; DOT-NEXT: ret
313327
entry:
314328
%a.sext = sext <16 x i8> %a to <16 x i32>
315329
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
320334
}
321335

322336
define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
323-
; CHECK-LABEL: vqdotu_vv_accum:
324-
; CHECK: # %bb.0: # %entry
325-
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
326-
; CHECK-NEXT: vwmulu.vv v10, v8, v9
327-
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
328-
; CHECK-NEXT: vwaddu.wv v12, v12, v10
329-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
330-
; CHECK-NEXT: vmv.s.x v8, zero
331-
; CHECK-NEXT: vredsum.vs v8, v12, v8
332-
; CHECK-NEXT: vmv.x.s a0, v8
333-
; CHECK-NEXT: ret
337+
; NODOT-LABEL: vqdotu_vv_accum:
338+
; NODOT: # %bb.0: # %entry
339+
; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
340+
; NODOT-NEXT: vwmulu.vv v10, v8, v9
341+
; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
342+
; NODOT-NEXT: vwaddu.wv v12, v12, v10
343+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
344+
; NODOT-NEXT: vmv.s.x v8, zero
345+
; NODOT-NEXT: vredsum.vs v8, v12, v8
346+
; NODOT-NEXT: vmv.x.s a0, v8
347+
; NODOT-NEXT: ret
348+
;
349+
; DOT-LABEL: vqdotu_vv_accum:
350+
; DOT: # %bb.0: # %entry
351+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
352+
; DOT-NEXT: vmv.v.i v10, 0
353+
; DOT-NEXT: vqdotu.vv v10, v8, v9
354+
; DOT-NEXT: vadd.vv v8, v10, v12
355+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
356+
; DOT-NEXT: vmv.v.v v12, v8
357+
; DOT-NEXT: vmv.s.x v8, zero
358+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
359+
; DOT-NEXT: vredsum.vs v8, v12, v8
360+
; DOT-NEXT: vmv.x.s a0, v8
361+
; DOT-NEXT: ret
334362
entry:
335363
%a.zext = zext <16 x i8> %a to <16 x i32>
336364
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
341369
}
342370

343371
define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
344-
; CHECK-LABEL: vqdotsu_vv_accum:
345-
; CHECK: # %bb.0: # %entry
346-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
347-
; CHECK-NEXT: vsext.vf2 v10, v8
348-
; CHECK-NEXT: vzext.vf2 v16, v9
349-
; CHECK-NEXT: vwmaccsu.vv v12, v10, v16
350-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
351-
; CHECK-NEXT: vmv.s.x v8, zero
352-
; CHECK-NEXT: vredsum.vs v8, v12, v8
353-
; CHECK-NEXT: vmv.x.s a0, v8
354-
; CHECK-NEXT: ret
372+
; NODOT-LABEL: vqdotsu_vv_accum:
373+
; NODOT: # %bb.0: # %entry
374+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
375+
; NODOT-NEXT: vsext.vf2 v10, v8
376+
; NODOT-NEXT: vzext.vf2 v16, v9
377+
; NODOT-NEXT: vwmaccsu.vv v12, v10, v16
378+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
379+
; NODOT-NEXT: vmv.s.x v8, zero
380+
; NODOT-NEXT: vredsum.vs v8, v12, v8
381+
; NODOT-NEXT: vmv.x.s a0, v8
382+
; NODOT-NEXT: ret
383+
;
384+
; DOT-LABEL: vqdotsu_vv_accum:
385+
; DOT: # %bb.0: # %entry
386+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
387+
; DOT-NEXT: vmv.v.i v10, 0
388+
; DOT-NEXT: vqdotsu.vv v10, v8, v9
389+
; DOT-NEXT: vadd.vv v8, v10, v12
390+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
391+
; DOT-NEXT: vmv.v.v v12, v8
392+
; DOT-NEXT: vmv.s.x v8, zero
393+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
394+
; DOT-NEXT: vredsum.vs v8, v12, v8
395+
; DOT-NEXT: vmv.x.s a0, v8
396+
; DOT-NEXT: ret
355397
entry:
356398
%a.sext = sext <16 x i8> %a to <16 x i32>
357399
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
455497
}
456498

457499
define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
458-
; CHECK-LABEL: vqdot_vv_split:
459-
; CHECK: # %bb.0: # %entry
460-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
461-
; CHECK-NEXT: vsext.vf2 v12, v8
462-
; CHECK-NEXT: vsext.vf2 v14, v9
463-
; CHECK-NEXT: vsext.vf2 v16, v10
464-
; CHECK-NEXT: vsext.vf2 v18, v11
465-
; CHECK-NEXT: vwmul.vv v8, v12, v14
466-
; CHECK-NEXT: vwmacc.vv v8, v16, v18
467-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
468-
; CHECK-NEXT: vmv.s.x v12, zero
469-
; CHECK-NEXT: vredsum.vs v8, v8, v12
470-
; CHECK-NEXT: vmv.x.s a0, v8
471-
; CHECK-NEXT: ret
500+
; NODOT-LABEL: vqdot_vv_split:
501+
; NODOT: # %bb.0: # %entry
502+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
503+
; NODOT-NEXT: vsext.vf2 v12, v8
504+
; NODOT-NEXT: vsext.vf2 v14, v9
505+
; NODOT-NEXT: vsext.vf2 v16, v10
506+
; NODOT-NEXT: vsext.vf2 v18, v11
507+
; NODOT-NEXT: vwmul.vv v8, v12, v14
508+
; NODOT-NEXT: vwmacc.vv v8, v16, v18
509+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
510+
; NODOT-NEXT: vmv.s.x v12, zero
511+
; NODOT-NEXT: vredsum.vs v8, v8, v12
512+
; NODOT-NEXT: vmv.x.s a0, v8
513+
; NODOT-NEXT: ret
514+
;
515+
; DOT-LABEL: vqdot_vv_split:
516+
; DOT: # %bb.0: # %entry
517+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
518+
; DOT-NEXT: vmv.v.i v12, 0
519+
; DOT-NEXT: vmv.v.i v13, 0
520+
; DOT-NEXT: vqdot.vv v12, v8, v9
521+
; DOT-NEXT: vqdot.vv v13, v10, v11
522+
; DOT-NEXT: vadd.vv v8, v12, v13
523+
; DOT-NEXT: vmv.s.x v9, zero
524+
; DOT-NEXT: vredsum.vs v8, v8, v9
525+
; DOT-NEXT: vmv.x.s a0, v8
526+
; DOT-NEXT: ret
472527
entry:
473528
%a.sext = sext <16 x i8> %a to <16 x i32>
474529
%b.sext = sext <16 x i8> %b to <16 x i32>

0 commit comments

Comments
 (0)