Skip to content

Commit 448bfce

Browse files
Merge branch 'main' into extend-offloading-api-llvm-objcopy
2 parents e493a07 + 33c9236 commit 448bfce

File tree

2 files changed

+133
-9
lines changed

2 files changed

+133
-9
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19123,18 +19123,18 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1912319123
SelectionDAG &DAG,
1912419124
const RISCVSubtarget &Subtarget,
1912519125
const RISCVTargetLowering &TLI) {
19126+
using namespace SDPatternMatch;
1912619127
// Note: We intentionally do not check the legality of the reduction type.
1912719128
// We want to handle the m4/m8 *src* types, and thus need to let illegal
1912819129
// intermediate types flow through here.
1912919130
if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
1913019131
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
1913119132
return SDValue();
1913219133

19133-
// Recurse through adds (since generic dag canonicalizes to that
19134-
// form). TODO: Handle disjoint or here.
19135-
if (InVec->getOpcode() == ISD::ADD) {
19136-
SDValue A = InVec.getOperand(0);
19137-
SDValue B = InVec.getOperand(1);
19134+
// Recurse through adds/disjoint ors (since generic dag canonicalizes to that
19135+
// form).
19136+
SDValue A, B;
19137+
if (sd_match(InVec, m_AddLike(m_Value(A), m_Value(B)))) {
1913819138
SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
1913919139
SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
1914019140
if (AOpt || BOpt) {
@@ -19171,12 +19171,9 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1917119171
// mul (zext a, zext b) -> partial_reduce_umla 0, a, b
1917219172
// mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
1917319173
// mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
19174-
if (InVec.getOpcode() != ISD::MUL)
19174+
if (!sd_match(InVec, m_Mul(m_Value(A), m_Value(B))))
1917519175
return SDValue();
1917619176

19177-
SDValue A = InVec.getOperand(0);
19178-
SDValue B = InVec.getOperand(1);
19179-
1918019177
if (!ISD::isExtOpcode(A.getOpcode()))
1918119178
return SDValue();
1918219179

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,133 @@ entry:
15521552
%res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %a.ext)
15531553
ret <4 x i32> %res
15541554
}
1555+
1556+
define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
1557+
; NODOT-LABEL: vqdot_vv_accum_disjoint_or:
1558+
; NODOT: # %bb.0: # %entry
1559+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
1560+
; NODOT-NEXT: vsext.vf2 v16, v8
1561+
; NODOT-NEXT: vsext.vf2 v18, v9
1562+
; NODOT-NEXT: vwmul.vv v8, v16, v18
1563+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
1564+
; NODOT-NEXT: vor.vv v8, v8, v12
1565+
; NODOT-NEXT: vmv.s.x v12, zero
1566+
; NODOT-NEXT: vredsum.vs v8, v8, v12
1567+
; NODOT-NEXT: vmv.x.s a0, v8
1568+
; NODOT-NEXT: ret
1569+
;
1570+
; DOT-LABEL: vqdot_vv_accum_disjoint_or:
1571+
; DOT: # %bb.0: # %entry
1572+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1573+
; DOT-NEXT: vmv1r.v v16, v12
1574+
; DOT-NEXT: vqdot.vv v16, v8, v9
1575+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
1576+
; DOT-NEXT: vmv.v.v v12, v16
1577+
; DOT-NEXT: vmv.s.x v8, zero
1578+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
1579+
; DOT-NEXT: vredsum.vs v8, v12, v8
1580+
; DOT-NEXT: vmv.x.s a0, v8
1581+
; DOT-NEXT: ret
1582+
entry:
1583+
%a.sext = sext <16 x i8> %a to <16 x i32>
1584+
%b.sext = sext <16 x i8> %b to <16 x i32>
1585+
%mul = mul <16 x i32> %a.sext, %b.sext
1586+
%add = or disjoint <16 x i32> %mul, %x
1587+
%sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
1588+
ret i32 %sum
1589+
}
1590+
1591+
define i32 @vqdot_vv_accum_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
1592+
; CHECK-LABEL: vqdot_vv_accum_or:
1593+
; CHECK: # %bb.0: # %entry
1594+
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
1595+
; CHECK-NEXT: vsext.vf2 v16, v8
1596+
; CHECK-NEXT: vsext.vf2 v18, v9
1597+
; CHECK-NEXT: vwmul.vv v8, v16, v18
1598+
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
1599+
; CHECK-NEXT: vor.vv v8, v8, v12
1600+
; CHECK-NEXT: vmv.s.x v12, zero
1601+
; CHECK-NEXT: vredsum.vs v8, v8, v12
1602+
; CHECK-NEXT: vmv.x.s a0, v8
1603+
; CHECK-NEXT: ret
1604+
entry:
1605+
%a.sext = sext <16 x i8> %a to <16 x i32>
1606+
%b.sext = sext <16 x i8> %b to <16 x i32>
1607+
%mul = mul <16 x i32> %a.sext, %b.sext
1608+
%add = or <16 x i32> %mul, %x
1609+
%sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
1610+
ret i32 %sum
1611+
}
1612+
1613+
define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
1614+
; NODOT-LABEL: vqdotu_vv_accum_disjoint_or:
1615+
; NODOT: # %bb.0: # %entry
1616+
; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
1617+
; NODOT-NEXT: vwmulu.vv v10, v8, v9
1618+
; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
1619+
; NODOT-NEXT: vwaddu.wv v12, v12, v10
1620+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
1621+
; NODOT-NEXT: vmv.s.x v8, zero
1622+
; NODOT-NEXT: vredsum.vs v8, v12, v8
1623+
; NODOT-NEXT: vmv.x.s a0, v8
1624+
; NODOT-NEXT: ret
1625+
;
1626+
; DOT-LABEL: vqdotu_vv_accum_disjoint_or:
1627+
; DOT: # %bb.0: # %entry
1628+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1629+
; DOT-NEXT: vmv1r.v v16, v12
1630+
; DOT-NEXT: vqdotu.vv v16, v8, v9
1631+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
1632+
; DOT-NEXT: vmv.v.v v12, v16
1633+
; DOT-NEXT: vmv.s.x v8, zero
1634+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
1635+
; DOT-NEXT: vredsum.vs v8, v12, v8
1636+
; DOT-NEXT: vmv.x.s a0, v8
1637+
; DOT-NEXT: ret
1638+
entry:
1639+
%a.zext = zext <16 x i8> %a to <16 x i32>
1640+
%b.zext = zext <16 x i8> %b to <16 x i32>
1641+
%mul = mul <16 x i32> %a.zext, %b.zext
1642+
%add = or disjoint <16 x i32> %mul, %x
1643+
%sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
1644+
ret i32 %sum
1645+
}
1646+
1647+
define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
1648+
; NODOT-LABEL: vqdotsu_vv_accum_disjoint_or:
1649+
; NODOT: # %bb.0: # %entry
1650+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
1651+
; NODOT-NEXT: vsext.vf2 v16, v8
1652+
; NODOT-NEXT: vzext.vf2 v18, v9
1653+
; NODOT-NEXT: vwmulsu.vv v8, v16, v18
1654+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
1655+
; NODOT-NEXT: vor.vv v8, v8, v12
1656+
; NODOT-NEXT: vmv.s.x v12, zero
1657+
; NODOT-NEXT: vredsum.vs v8, v8, v12
1658+
; NODOT-NEXT: vmv.x.s a0, v8
1659+
; NODOT-NEXT: ret
1660+
;
1661+
; DOT-LABEL: vqdotsu_vv_accum_disjoint_or:
1662+
; DOT: # %bb.0: # %entry
1663+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1664+
; DOT-NEXT: vmv1r.v v16, v12
1665+
; DOT-NEXT: vqdotsu.vv v16, v8, v9
1666+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
1667+
; DOT-NEXT: vmv.v.v v12, v16
1668+
; DOT-NEXT: vmv.s.x v8, zero
1669+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
1670+
; DOT-NEXT: vredsum.vs v8, v12, v8
1671+
; DOT-NEXT: vmv.x.s a0, v8
1672+
; DOT-NEXT: ret
1673+
entry:
1674+
%a.sext = sext <16 x i8> %a to <16 x i32>
1675+
%b.zext = zext <16 x i8> %b to <16 x i32>
1676+
%mul = mul <16 x i32> %a.sext, %b.zext
1677+
%add = or disjoint <16 x i32> %mul, %x
1678+
%sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
1679+
ret i32 %sum
1680+
}
1681+
15551682
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
15561683
; DOT32: {{.*}}
15571684
; DOT64: {{.*}}

0 commit comments

Comments
 (0)