Merged
32 changes: 32 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5721,6 +5721,38 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
    return Cost;
  }

  // FIXME:
  // 1. Do cost modelling for USDOT.
  // 2. Refactor the whole code here.
  if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
    if (AccumLT.second.getScalarType() == MVT::i32 &&
        InputLT.second.getScalarType() == MVT::i16) {
      // i16 -> i32 is supported in SVE 2.1
      if (ST->hasSVE2p1())
        return Cost;
Collaborator

This implements a new requirement that should be in a separate PR. Also, it belongs earlier in the function, just below line 5712.

Contributor Author

resolved.

      // umlalt + umlalb. Same goes for signed types.
      return Cost + 1;
    }
    if (AccumLT.second.getScalarType() == MVT::i64 &&
        InputLT.second.getScalarType() == MVT::i32)
      return Cost + 1;
  }
  if (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
      ST->hasDotProd() && !IsUSDot) {
    // umull + umull2 + (2 * uaddw) + (2 * uaddw2). Same goes for signed types.
    if (AccumLT.second.getScalarType() == MVT::i64 &&
        InputLT.second.getScalarType() == MVT::i16)
      return Cost + 5;
Collaborator

This is not correct for targets with SVE2, which would use sdot. I'd also like us to avoid having to encode every combination at this level of detail. The function is set up to return a cheap cost for all legal cases, and a higher cost for all illegal cases (those that would require extends).

For targets where SVE(2) is not available, perhaps you can just return a higher default at the bottom of this function. But as it stands, the suggestion I made below (to return Cost + 2) causes all the tests to pass, so test coverage is currently missing for the (NEON-only) case you're trying to support.
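
To illustrate the shape described in this comment (a cheap cost for legal cases, one higher default for everything that needs extends), here is a minimal standalone sketch. It is not the LLVM implementation: `ElemTy`, `Features`, `isLegalPartialReduction`, `partialReductionCost` and `ExtendPenalty` are hypothetical names, the legality table is a toy with only two entries, and the penalty of 2 simply mirrors the `Cost + 2` suggested in this review.

```cpp
// Toy model only; none of these types exist in LLVM.
enum class ElemTy { I8, I16, I32, I64 };

// Hypothetical feature flags standing in for the real subtarget queries.
struct Features {
  bool HasSVE;         // assumption: an SVE-style dot product is usable
  bool HasNeonDotProd; // assumption: the NEON dot-product extension is usable
};

// Illustrative legality table: true when the (input, accumulator) pair is
// treated as mapping directly onto a dot-product instruction in this model.
bool isLegalPartialReduction(ElemTy In, ElemTy Acc, const Features &F) {
  if (In == ElemTy::I8 && Acc == ElemTy::I32)
    return F.HasSVE || F.HasNeonDotProd;
  if (In == ElemTy::I16 && Acc == ElemTy::I64)
    return F.HasSVE;
  return false;
}

// Single default penalty for illegal cases, mirroring `Cost + 2`.
constexpr int ExtendPenalty = 2;

// Legal cases return the base cost unchanged; every other combination falls
// through to one shared higher cost instead of a per-combination branch.
int partialReductionCost(int BaseCost, ElemTy In, ElemTy Acc,
                         const Features &F) {
  if (isLegalPartialReduction(In, Acc, F))
    return BaseCost;
  return BaseCost + ExtendPenalty;
}
```

In this sketch, supporting a new combination means adding one entry to the legality check; the fall-through cost stays in a single place, so tuning it for NEON-only targets would be a one-line change.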

Contributor Author

Resolved. Set it to Cost + 2.

    // umlal + umlal2. Same goes for signed types.
    if ((AccumLT.second.getScalarType() == MVT::i32 &&
         InputLT.second.getScalarType() == MVT::i16) ||
        (AccumLT.second.getScalarType() == MVT::i64 &&
         InputLT.second.getScalarType() == MVT::i32))
      return Cost + 1;
  }

  // FIXME: This should be more expensive for NEON as we see fmov instructions
  // with very low throughput.
  // Add additional cost for the extends that would need to be inserted.
  return Cost + 4;
Collaborator

My suggestion was merely to change the "return Cost + 4;" at the bottom of this function to "return Cost + 2;".

Contributor Author

done

}