Merged
46 changes: 40 additions & 6 deletions llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -61,6 +61,7 @@
 
 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
 
+  /// In the case of reductions in unrolled loops, the %OutsideUser from
+  /// ReductionInfo is an add instruction that precedes the reduction.
+  /// UnrollInfo pairs values together if they are both operands of the same
+  /// add. This pairing info is then used to add the resulting complex
+  /// operations together before the final reduction.
+  MapVector<Value *, Value *> UnrollInfo;
+
   /// In the process of detecting a reduction, we consider a pair of
   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
   /// traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
   auto *FinalReduction = ReductionInfo[Real].second;
   Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
 
-  auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+  Value *Other;
+  bool EraseFinalReductionHere = false;
+  if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
Collaborator:
When I replace the add in this test by a sub, the pass still crashes, so this is not sufficient.
Does it matter what the operation (the one outside the loop) actually is?

I would have expected something like this:

define <vscale x 4 x i32> @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
  %a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
  %b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
  %a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
  %a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
  %b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
  %b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
  %a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
  %a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
  %b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
  %b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
  %real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
  %real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
  %imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
  %imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
  %partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  ret <vscale x 4 x i32> %partial.reduce.sub
}

to use cdot instructions as well, but this case also seems to crash. This suggests that the issue is not to do with unrolling, but rather with the user outside the loop being anything other than a reduction?

@NickGuy-Arm (Contributor, author) commented on Mar 14, 2025:
I think this is a separate issue; I wouldn't expect that snippet to be processed, as the current implementation would require changing the function return type. The issue this patch aims to fix is when it tries to replace one operand of an add with a value of a different type.

If this were to instead reinterleave and store the complex result in middle.block, instead of returning it, then I would expect the pass to process it and emit cdot instructions.

> When I replace the add in this test by a sub, the pass still crashes, so this is not sufficient.

I'm not sure the loop vectorizer would ever emit a sub here. Please do correct me if I'm wrong, but I'm not seeing any VECREDUCE_ADD or vecreduce.add equivalent for subtraction, and %bin.rdx in this case is derived from the reduction intrinsic.
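For concreteness, this is the kind of unrolled input under discussion (a hand-written sketch with assumed value names, not a test from the patch): the loop is interleaved by a factor of two, each part accumulates through its own partial-reduction chain, and %bin.rdx is the add outside the loop that feeds the final llvm.vector.reduce.add. The new UnrollInfo map pairs the two chain results because they are both operands of %bin.rdx.

define i32 @cdotp_i8_rot0_unrolled(<vscale x 32 x i8> %a0, <vscale x 32 x i8> %b0, <vscale x 32 x i8> %a1, <vscale x 32 x i8> %b1) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %red0, %vector.body ]
  %vec.phi1 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %red1, %vector.body ]
  ; unrolled part 0: the same deinterleave/extend/mul/partial-reduce chain as
  ; in the non-unrolled test
  %a0.de = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a0)
  %b0.de = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b0)
  %a0.re = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.de, 0
  %a0.im = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.de, 1
  %b0.re = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.de, 0
  %b0.im = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.de, 1
  %a0.re.ext = sext <vscale x 16 x i8> %a0.re to <vscale x 16 x i32>
  %a0.im.ext = sext <vscale x 16 x i8> %a0.im to <vscale x 16 x i32>
  %b0.re.ext = sext <vscale x 16 x i8> %b0.re to <vscale x 16 x i32>
  %b0.im.ext = sext <vscale x 16 x i8> %b0.im to <vscale x 16 x i32>
  %mul0.re = mul <vscale x 16 x i32> %b0.re.ext, %a0.re.ext
  %red0.re = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %mul0.re)
  %mul0.im = mul <vscale x 16 x i32> %b0.im.ext, %a0.im.ext
  %mul0.im.neg = sub <vscale x 16 x i32> zeroinitializer, %mul0.im
  %red0 = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %red0.re, <vscale x 16 x i32> %mul0.im.neg)
  ; unrolled part 1: an identical chain on the second pair of inputs
  %a1.de = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a1)
  %b1.de = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b1)
  %a1.re = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.de, 0
  %a1.im = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.de, 1
  %b1.re = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b1.de, 0
  %b1.im = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b1.de, 1
  %a1.re.ext = sext <vscale x 16 x i8> %a1.re to <vscale x 16 x i32>
  %a1.im.ext = sext <vscale x 16 x i8> %a1.im to <vscale x 16 x i32>
  %b1.re.ext = sext <vscale x 16 x i8> %b1.re to <vscale x 16 x i32>
  %b1.im.ext = sext <vscale x 16 x i8> %b1.im to <vscale x 16 x i32>
  %mul1.re = mul <vscale x 16 x i32> %b1.re.ext, %a1.re.ext
  %red1.re = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi1, <vscale x 16 x i32> %mul1.re)
  %mul1.im = mul <vscale x 16 x i32> %b1.im.ext, %a1.im.ext
  %mul1.im.neg = sub <vscale x 16 x i32> zeroinitializer, %mul1.im
  %red1 = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %red1.re, <vscale x 16 x i32> %mul1.im.neg)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  ; %bin.rdx is the OutsideUser that processReductionSingle now matches with
  ; m_c_Add; its operands are the two reduction chains paired via UnrollInfo
  %bin.rdx = add <vscale x 4 x i32> %red1, %red0
  %res = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %bin.rdx)
  ret i32 %res
}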

@igogo-x86 (Contributor) commented:
For regular reductions (without cdot), we needed to analyse and rewrite uses outside of the loop due to Real and Imaginary part extraction. See the cases in complex-deinterleaving-reductions.ll. But for cdot, we don't need to do any of that. Here's a test from complex-deinterleaving-cdot.ll:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
  %a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
  %b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
  %a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
  %a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
  %b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
  %b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
  %a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
  %a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
  %b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
  %b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
  %real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
  %real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
  %imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
  %imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
  %partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %0 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce.sub)
  ret i32 %0
}

It is currently transformed into:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %0 = phi <vscale x 8 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
  %1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
  %2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
  %3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
  %4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
  %5 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 0)
  %6 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 4)
  %7 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
  %8 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %6, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
  %9 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> %7, i64 0)
  %10 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> %9, <vscale x 4 x i32> %8, i64 4)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %11 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %10)
  ret i32 %11
}

But instead, we could ignore everything happening after the final llvm.experimental.vector.partial.reduce.add and just feed one cdot into the other:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %0 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %6, %vector.body ]
  %1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
  %2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
  %3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
  %4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
  %5 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %0, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
  %6 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %result = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %6)
  ret i32 %result
}

Collaborator:
Thanks @igogo-x86, I agree it makes more sense to feed the result from one cdot into the other, rather than changing the PHI node to be a wider type.

My understanding is that the Complex Deinterleaving pass was created because, for certain operations (like reductions) where there is both a PHI+reduction for the imaginary part and one for the real part, it is better to keep the intermediate values interleaved: the cmla instruction takes interleaved tuples as input and returns interleaved tuples as output. This avoids having to deinterleave values first, and it also allows using the specialised cmla instructions to do the complex MLA operation. The reduction PHI then contains a vector of <(r, i), (r, i), ...> tuples, which need deinterleaving only when doing the final reduction.
For the case of cdot instructions there is no need for this, because the result vector will always be deinterleaved (the cdot instruction returns either a widened real or a widened imaginary result).
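As a toy illustration of that point (a hand-written sketch with assumed names, not a test from complex-deinterleaving-reductions.ll): the accumulator below stays in interleaved <(r, i), (r, i), ...> form for the whole loop, and the parts are only deinterleaved once, for the two final reductions.

define { float, float } @interleaved_accumulator(<vscale x 8 x float> %in) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  ; %acc holds interleaved (real, imag) pairs; nothing in the loop needs to
  ; deinterleave, which is what lets the pass use cmla on interleaved data
  %acc = phi <vscale x 8 x float> [ zeroinitializer, %entry ], [ %acc.next, %vector.body ]
  %acc.next = fadd <vscale x 8 x float> %acc, %in
  br i1 true, label %middle.block, label %vector.body

middle.block:                                    ; preds = %vector.body
  ; deinterleave once, only to feed the final reduction of each part
  %dei = call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> %acc.next)
  %re.vec = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %dei, 0
  %im.vec = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %dei, 1
  %re = call float @llvm.vector.reduce.fadd.nxv4f32(float 0.000000e+00, <vscale x 4 x float> %re.vec)
  %im = call float @llvm.vector.reduce.fadd.nxv4f32(float 0.000000e+00, <vscale x 4 x float> %im.vec)
  %ret0 = insertvalue { float, float } poison, float %re, 0
  %ret1 = insertvalue { float, float } %ret0, float %im, 1
  ret { float, float } %ret1
}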

If that understanding is correct, then I don't really see a need to implement this optimization in the ComplexDeinterleaving pass. This looks more like a DAGCombine of partialreduce(mul(ext(deinterleave(a)), ext(deinterleave(b)))) -> cdot(a, b, #0) (with some variation of this pattern for other rotations). With the new ISD node ISD::PARTIAL_REDUCE_[U|S]MLA added by @JamesChesterman this should be even easier to identify.

Please let me know if I'm missing anything here though.

+    UnrollInfo[Real] = OperationReplacement;
+    if (!UnrollInfo.contains(Other) || !FinalReduction->hasOneUser())
+      return;
+
+    auto *User = *FinalReduction->user_begin();
+    if (!match(User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
+      return;
+
+    FinalReduction = cast<Instruction>(User);
+    Builder.SetInsertPoint(FinalReduction);
+    OperationReplacement =
+        Builder.CreateAdd(OperationReplacement, UnrollInfo[Other]);
+
+    UnrollInfo.erase(Real);
+    UnrollInfo.erase(Other);
+    EraseFinalReductionHere = true;
+  }
+
+  Value *AddReduce = Builder.CreateAddReduce(OperationReplacement);
   FinalReduction->replaceAllUsesWith(AddReduce);
+  if (EraseFinalReductionHere)
+    FinalReduction->eraseFromParent();
 }
 
 void ComplexDeinterleavingGraph::processReductionOperation(
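For reference, a sketch of the output shape the code above aims for in the unrolled case (assumed names, not CHECK lines from the patch's tests): each unrolled part keeps its own narrow PHI and cdot chain, and the add that UnrollInfo recorded is re-created over the two complex results before a single final reduction.

define i32 @expected_unrolled_output(<vscale x 16 x i8> %a0, <vscale x 16 x i8> %b0, <vscale x 16 x i8> %a1, <vscale x 16 x i8> %b1) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %phi0 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %cdot0, %vector.body ]
  %phi1 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %cdot1, %vector.body ]
  %cdot0 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %phi0, <vscale x 16 x i8> %a0, <vscale x 16 x i8> %b0, i32 0)
  %cdot1 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %phi1, <vscale x 16 x i8> %a1, <vscale x 16 x i8> %b1, i32 0)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  ; the add rebuilt by processReductionSingle from UnrollInfo, then one reduce
  %sum = add <vscale x 4 x i32> %cdot0, %cdot1
  %res = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %sum)
  ret i32 %res
}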
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
-  SmallVector<Instruction *, 16> DeadInstrRoots;
+  SmallSetVector<Instruction *, 16> DeadInstrRoots;
   for (auto *RootInstruction : OrderedRoots) {
     // Check if this potential root went through check process and we can
     // deinterleave it
@@ -2316,20 +2347,23 @@
       auto *RootImag = cast<Instruction>(RootNode->Imag);
       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(RootReal);
-      DeadInstrRoots.push_back(RootImag);
+      DeadInstrRoots.insert(RootReal);
+      DeadInstrRoots.insert(RootImag);
     } else if (RootNode->Operation ==
                ComplexDeinterleavingOperation::ReductionSingle) {
       auto *RootInst = cast<Instruction>(RootNode->Real);
       ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
+      DeadInstrRoots.insert(ReductionInfo[RootInst].second);
     } else {
       assert(R && "Unable to find replacement for RootInstruction");
-      DeadInstrRoots.push_back(RootInstruction);
+      DeadInstrRoots.insert(RootInstruction);
       RootInstruction->replaceAllUsesWith(R);
     }
   }
 
+  assert(UnrollInfo.empty() &&
+         "UnrollInfo should be empty after replacing all nodes");
+
   for (auto *I : DeadInstrRoots)
     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
 }