
Conversation

@sushgokh
Contributor

natively supported on Neon and SVE

PR #158641 refined and refactored the cost model for partial reductions. In doing so, it dropped certain constraints: cases like an i32 -> i64 partial reduction are not natively supported. This patch adds back the condition/constraint that was present before PR #158641.
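For illustration, the reduction pattern at issue looks roughly like this in scalar form (a hedged sketch mirroring the sext_reduction_i32_to_i64 test updated below; the function name is illustrative only):

  // Scalar shape of the i32 -> i64 reduction under discussion: each i32
  // element is sign-extended to i64 and accumulated. A vectorised "partial
  // reduction" keeps several i64 accumulator lanes live and only combines
  // them after the loop, so the accumulator/input width ratio (64/32 = 2
  // here) matters for what the target can lower natively.
  long long sum_sext_i32_to_i64(const int *Arr, long long N) {
    long long Sum = 0;
    for (long long I = 0; I < N; ++I)
      Sum += (long long)Arr[I]; // sext i32 -> i64, then add
    return Sum;
  }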

@llvmbot
Member

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-transforms

Author: Sushant Gokhale (sushgokh)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/163728.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+3)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll (+41-41)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 479e34515fc8a..bfb2cacbd7dca 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5661,6 +5661,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
       AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
   if (VF.getKnownMinValue() <= Ratio)
     return Invalid;
+  // i32 -> i64 or i16 -> i32 is not natively supported on Neon and SVE.
+  if (Ratio < 4)
+    return Invalid;
 
   VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
   VectorType *AccumVectorType =
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
index 1e6bcb12029e7..0a61e8e9a505f 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
@@ -1117,24 +1117,24 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
 ; CHECK-INTERLEAVE1-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
 ; CHECK-INTERLEAVE1-NEXT:  entry:
 ; CHECK-INTERLEAVE1-NEXT:    [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-INTERLEAVE1-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 2
+; CHECK-INTERLEAVE1-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 4
 ; CHECK-INTERLEAVE1-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-INTERLEAVE1:       vector.ph:
-; CHECK-INTERLEAVE1-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 2
+; CHECK-INTERLEAVE1-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 4
 ; CHECK-INTERLEAVE1-NEXT:    [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
 ; CHECK-INTERLEAVE1-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-INTERLEAVE1:       vector.body:
 ; CHECK-INTERLEAVE1-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVE1-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVE1-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-INTERLEAVE1-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-INTERLEAVE1-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-INTERLEAVE1-NEXT:    [[TMP1:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-INTERLEAVE1-NEXT:    [[TMP2]] = add <2 x i64> [[VEC_PHI]], [[TMP1]]
-; CHECK-INTERLEAVE1-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
+; CHECK-INTERLEAVE1-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-INTERLEAVE1-NEXT:    [[TMP1:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-INTERLEAVE1-NEXT:    [[TMP2]] = add <4 x i64> [[VEC_PHI]], [[TMP1]]
+; CHECK-INTERLEAVE1-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-INTERLEAVE1-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-INTERLEAVE1-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
 ; CHECK-INTERLEAVE1:       middle.block:
-; CHECK-INTERLEAVE1-NEXT:    [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[TMP2]])
+; CHECK-INTERLEAVE1-NEXT:    [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
 ; CHECK-INTERLEAVE1-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
 ; CHECK-INTERLEAVE1-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK-INTERLEAVE1:       scalar.ph:
@@ -1143,42 +1143,42 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
 ; CHECK-INTERLEAVED-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
 ; CHECK-INTERLEAVED-NEXT:  entry:
 ; CHECK-INTERLEAVED-NEXT:    [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-INTERLEAVED-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 8
+; CHECK-INTERLEAVED-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 16
 ; CHECK-INTERLEAVED-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-INTERLEAVED:       vector.ph:
-; CHECK-INTERLEAVED-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 8
+; CHECK-INTERLEAVED-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 16
 ; CHECK-INTERLEAVED-NEXT:    [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
 ; CHECK-INTERLEAVED-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-INTERLEAVED:       vector.body:
 ; CHECK-INTERLEAVED-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI1:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI2:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI3:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI2:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI3:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-INTERLEAVED-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-INTERLEAVED-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 2
 ; CHECK-INTERLEAVED-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 4
-; CHECK-INTERLEAVED-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 6
-; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD4:%.*]] = load <2 x i32>, ptr [[TMP1]], align 4
-; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD5:%.*]] = load <2 x i32>, ptr [[TMP2]], align 4
-; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD6:%.*]] = load <2 x i32>, ptr [[TMP3]], align 4
-; CHECK-INTERLEAVED-NEXT:    [[TMP14:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT:    [[TMP5:%.*]] = sext <2 x i32> [[WIDE_LOAD4]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT:    [[TMP6:%.*]] = sext <2 x i32> [[WIDE_LOAD5]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT:    [[TMP7:%.*]] = sext <2 x i32> [[WIDE_LOAD6]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT:    [[TMP8]] = add <2 x i64> [[VEC_PHI]], [[TMP14]]
-; CHECK-INTERLEAVED-NEXT:    [[TMP9]] = add <2 x i64> [[VEC_PHI1]], [[TMP5]]
-; CHECK-INTERLEAVED-NEXT:    [[TMP10]] = add <2 x i64> [[VEC_PHI2]], [[TMP6]]
-; CHECK-INTERLEAVED-NEXT:    [[TMP11]] = add <2 x i64> [[VEC_PHI3]], [[TMP7]]
-; CHECK-INTERLEAVED-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-INTERLEAVED-NEXT:    [[TMP14:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 8
+; CHECK-INTERLEAVED-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 12
+; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD4:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
+; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD5:%.*]] = load <4 x i32>, ptr [[TMP14]], align 4
+; CHECK-INTERLEAVED-NEXT:    [[WIDE_LOAD6:%.*]] = load <4 x i32>, ptr [[TMP3]], align 4
+; CHECK-INTERLEAVED-NEXT:    [[TMP15:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[WIDE_LOAD4]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT:    [[TMP6:%.*]] = sext <4 x i32> [[WIDE_LOAD5]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT:    [[TMP7:%.*]] = sext <4 x i32> [[WIDE_LOAD6]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT:    [[TMP8]] = add <4 x i64> [[VEC_PHI]], [[TMP15]]
+; CHECK-INTERLEAVED-NEXT:    [[TMP9]] = add <4 x i64> [[VEC_PHI1]], [[TMP5]]
+; CHECK-INTERLEAVED-NEXT:    [[TMP10]] = add <4 x i64> [[VEC_PHI2]], [[TMP6]]
+; CHECK-INTERLEAVED-NEXT:    [[TMP11]] = add <4 x i64> [[VEC_PHI3]], [[TMP7]]
+; CHECK-INTERLEAVED-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-INTERLEAVED-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-INTERLEAVED-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
 ; CHECK-INTERLEAVED:       middle.block:
-; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX:%.*]] = add <2 x i64> [[TMP9]], [[TMP8]]
-; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX7:%.*]] = add <2 x i64> [[TMP10]], [[BIN_RDX]]
-; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX8:%.*]] = add <2 x i64> [[TMP11]], [[BIN_RDX7]]
-; CHECK-INTERLEAVED-NEXT:    [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[BIN_RDX8]])
+; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX:%.*]] = add <4 x i64> [[TMP9]], [[TMP8]]
+; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX7:%.*]] = add <4 x i64> [[TMP10]], [[BIN_RDX]]
+; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX8:%.*]] = add <4 x i64> [[TMP11]], [[BIN_RDX7]]
+; CHECK-INTERLEAVED-NEXT:    [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[BIN_RDX8]])
 ; CHECK-INTERLEAVED-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
 ; CHECK-INTERLEAVED-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK-INTERLEAVED:       scalar.ph:
@@ -1187,24 +1187,24 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
 ; CHECK-MAXBW-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
 ; CHECK-MAXBW-NEXT:  entry:
 ; CHECK-MAXBW-NEXT:    [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-MAXBW-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 2
+; CHECK-MAXBW-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 4
 ; CHECK-MAXBW-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-MAXBW:       vector.ph:
-; CHECK-MAXBW-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 2
+; CHECK-MAXBW-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 4
 ; CHECK-MAXBW-NEXT:    [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
 ; CHECK-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-MAXBW:       vector.body:
 ; CHECK-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-MAXBW-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-MAXBW-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-MAXBW-NEXT:    [[TMP1:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-MAXBW-NEXT:    [[TMP2]] = add <2 x i64> [[VEC_PHI]], [[TMP1]]
-; CHECK-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
+; CHECK-MAXBW-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-MAXBW-NEXT:    [[TMP1:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-MAXBW-NEXT:    [[TMP2]] = add <4 x i64> [[VEC_PHI]], [[TMP1]]
+; CHECK-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-MAXBW-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-MAXBW-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
 ; CHECK-MAXBW:       middle.block:
-; CHECK-MAXBW-NEXT:    [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[TMP2]])
+; CHECK-MAXBW-NEXT:    [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
 ; CHECK-MAXBW-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
 ; CHECK-MAXBW-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK-MAXBW:       scalar.ph:

      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
  if (VF.getKnownMinValue() <= Ratio)
    return Invalid;
  // i32 -> i64 or i16 -> i32 is not natively supported on Neon and SVE.
Contributor

This isn't true because SVE2.1 has support for udot z0.s, z1.h, z2.h
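For readers unfamiliar with that instruction, a rough scalar sketch of the 2-way dot product described above (lane counts illustrative, based on the comment rather than the architecture manual):

  // Rough semantics of udot z0.s, z1.h, z2.h (SVE2.1 2-way dot product): each
  // 32-bit accumulator lane consumes a pair of unsigned 16-bit lanes from each
  // source, i.e. an i16 -> i32 multiply-accumulate partial reduction.
  void udot_2way(unsigned Acc[4], const unsigned short A[8],
                 const unsigned short B[8]) {
    for (int Lane = 0; Lane < 4; ++Lane)
      Acc[Lane] += (unsigned)A[2 * Lane] * B[2 * Lane] +
                   (unsigned)A[2 * Lane + 1] * B[2 * Lane + 1];
  }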

Contributor

Also, we can in theory lower code for a ratio of 2 even for NEON or SVE. It just might not be optimal codegen. Perhaps the cost modelling below needs modifying to handle the case you are worried about?
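As a sketch of that point, one possible (hypothetical, not necessarily optimal) ratio-2 lowering shape for an i32 -> i64 add-style partial reduction, using widening adds such as saddw/saddw2:

  // Hypothetical ratio-2 step: a <4 x i32> input folded into a <2 x i64>
  // accumulator. Lane order inside the accumulator is irrelevant for a
  // reduction, so this folding is legal even if not the best codegen.
  void ratio2_step_i32_to_i64(long long Acc[2], const int In[4]) {
    // e.g. saddw: widen-add the low <2 x i32> half into the i64 lanes.
    Acc[0] += (long long)In[0];
    Acc[1] += (long long)In[1];
    // e.g. saddw2: widen-add the high <2 x i32> half.
    Acc[0] += (long long)In[2];
    Acc[1] += (long long)In[3];
  }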

Contributor Author

> This isn't true because SVE2.1 has support for udot z0.s, z1.h, z2.h

Ah, missed this.

> Also, we can in theory lower code for a ratio of 2 even for NEON or SVE. It just might not be optimal codegen. Perhaps the cost modelling below needs modifying to handle the case you are worried about?

Will try to cost-model the codegen here.

Contributor Author

Done.

  if (VF.getKnownMinValue() <= Ratio)
    return Invalid;
  // i32 -> i64 or i16 -> i32 is not natively supported on Neon and SVE.
  if (Ratio < 4)
Collaborator

A partial reduction from i32 -> i64 or i16 -> i32 is supported with smlalb+smlalt. If the cost at the end of the function would be slightly tweaked to return Cost + 2; (instead of return Cost + 4), then this loop vectorises with partial reductions and VF=4.
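A rough scalar sketch of the bottom/top widening multiply-accumulate referred to above (lane counts illustrative, assuming 128-bit vectors):

  // smlalb multiplies the even-indexed (bottom) i16 lanes and accumulates
  // into i32 lanes; smlalt does the same for the odd-indexed (top) lanes.
  // Together they give an i16 -> i32 multiply-accumulate partial reduction.
  void smlalb_smlalt_step(int Acc[4], const short A[8], const short B[8]) {
    for (int Lane = 0; Lane < 4; ++Lane) {
      Acc[Lane] += (int)A[2 * Lane] * (int)B[2 * Lane];         // smlalb
      Acc[Lane] += (int)A[2 * Lane + 1] * (int)B[2 * Lane + 1]; // smlalt
    }
  }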

Contributor Author

Thanks @sdesmalen-arm. Have made some changes.

This adds some additional cost modelling and corresponding test cases.
@sushgokh
Contributor Author

@david-arm @sdesmalen-arm

  1. Made the cost-model change that helps my test.
  2. Added a few more cost-model changes and corresponding tests.

Comment on lines 5731 to 5732
  if (ST->hasSVE2p1())
    return Cost;
Collaborator

This is implementing a new requirement that should be in a separate PR. Also, this should be added below line 5712 above.

Contributor Author

Resolved.

  // FIXME: This should be more expensive for NEON as we see fmov instructions
  // with very low throughput.
  // Add additional cost for the extends that would need to be inserted.
  return Cost + 4;
Collaborator

My suggestion was merely to change the return Cost + 4; at the bottom of this function to return Cost + 2;.

Contributor Author

Done.

Comment on lines 5742 to 5745
  // umull + umull2 + (2 * uaddw) + (2 * uaddw2). Same goes for signed types.
  if (AccumLT.second.getScalarType() == MVT::i64 &&
      InputLT.second.getScalarType() == MVT::i16)
    return Cost + 5;
Collaborator

This is not correct for targets with SVE2, which would use sdot. I'd also like us to avoid having to encode every combination to this level of detail. The function is set up to return a cheap cost for all legal cases, and a higher cost for all illegal cases (that would require extends).

For targets where SVE(2) is not available, perhaps you can just return a higher default at the bottom of this function. But as it stands, the suggestion I made below (to return Cost + 2) causes all the tests to pass, so there is currently missing test-coverage for the (NEON only) case you're trying to support.
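For reference, a hedged scalar breakdown of the NEON sequence named in the quoted comment (lane mapping illustrative, not necessarily the exact codegen):

  // umull/umull2 widen-multiply the low/high i16 halves into <4 x i32>
  // products; uaddw/uaddw2 then widen-add each product vector (one uaddw plus
  // one uaddw2 per product vector) into the <2 x i64> accumulators.
  void neon_i16_to_i64_step(unsigned long long Acc[2], const unsigned short A[8],
                            const unsigned short B[8]) {
    unsigned Prod[8];
    for (int I = 0; I < 8; ++I) // umull (I < 4) and umull2 (I >= 4)
      Prod[I] = (unsigned)A[I] * B[I];
    for (int I = 0; I < 8; ++I) // 2 x uaddw + 2 x uaddw2 in total
      Acc[I & 1] += Prod[I];
  }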

Contributor Author

Resolved. Set it to Cost + 2.

@sushgokh merged commit 005ec78 into llvm:main on Oct 21, 2025
9 of 10 checks passed
@sushgokh deleted the features/sgokhale/add-partial-reduction-constraint branch on October 21, 2025 at 00:37