-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[AArch64][GlobalISel] Improve lowering of vector fp16 fptrunc #163398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
737470e
ae3ef1e
43b1509
437caa3
6abe127
13fc5dc
0ceacd7
8b85744
ec102fc
411afc0
60b6da7
76a03d6
a5635b7
5f97537
3671057
39c3e04
0604176
3513809
74aa139
a1bf07a
60cbbc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" | ||
| #include "llvm/CodeGen/GlobalISel/Utils.h" | ||
| #include "llvm/CodeGen/MachineInstr.h" | ||
| #include "llvm/CodeGen/MachineInstrBuilder.h" | ||
| #include "llvm/CodeGen/MachineRegisterInfo.h" | ||
| #include "llvm/CodeGen/TargetOpcodes.h" | ||
| #include "llvm/IR/DerivedTypes.h" | ||
|
|
@@ -820,8 +821,17 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) | |
| .legalFor( | ||
| {{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}}) | ||
| .libcallFor({{s16, s128}, {s32, s128}, {s64, s128}}) | ||
| .clampNumElements(0, v4s16, v4s16) | ||
| .clampNumElements(0, v2s32, v2s32) | ||
| .moreElementsToNextPow2(1) | ||
| .customIf([](const LegalityQuery &Q) { | ||
HolyMolyCowMan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| LLT DstTy = Q.Types[0]; | ||
| LLT SrcTy = Q.Types[1]; | ||
| return SrcTy.isFixedVector() && DstTy.isFixedVector() && | ||
| SrcTy.getScalarSizeInBits() == 64 && | ||
| DstTy.getScalarSizeInBits() == 16; | ||
| }) | ||
| // Clamp based on the source (input) type, i.e. type index 1 | ||
| .clampNumElements(1, v4s32, v4s32) | ||
| .clampNumElements(1, v2s64, v2s64) | ||
| .scalarize(0); | ||
|
|
||
| getActionDefinitionsBuilder(G_FPEXT) | ||
|
|
@@ -1479,6 +1489,10 @@ bool AArch64LegalizerInfo::legalizeCustom( | |
| return legalizeICMP(MI, MRI, MIRBuilder); | ||
| case TargetOpcode::G_BITCAST: | ||
| return legalizeBitcast(MI, Helper); | ||
| case TargetOpcode::G_FPTRUNC: | ||
| // In order to vectorise f64 to f16 properly, we need to use f32 as an | ||
HolyMolyCowMan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // intermediary | ||
| return legalizeFptrunc(MI, MIRBuilder, MRI); | ||
| } | ||
|
|
||
| llvm_unreachable("expected switch to return"); | ||
|
|
@@ -2416,3 +2430,81 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI, | |
| MI.eraseFromParent(); | ||
| return true; | ||
| } | ||
|
|
||
| bool AArch64LegalizerInfo::legalizeFptrunc(MachineInstr &MI, | ||
| MachineIRBuilder &MIRBuilder, | ||
| MachineRegisterInfo &MRI) const { | ||
| auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs(); | ||
| assert(SrcTy.isFixedVector() && isPowerOf2_32(SrcTy.getNumElements()) && | ||
| "Expected a power of 2 elements"); | ||
|
Comment on lines
+2438
to
+2439
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we make this work with a "multiple of 2", not a "power of 2"?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can do but currently the legalizer widens the fptrunc src to the next power of 2, meaning we can keep this simple if we only expect powers of 2. Otherwise, we might have to pad vectors so we can later concat them.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh - that fine for now then. We should go through at some point and check non-power2 vector types.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I add a todo comment?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No thats fine, we need to go through all of them I think. |
||
|
|
||
| LLT s16 = LLT::scalar(16); | ||
| LLT s32 = LLT::scalar(32); | ||
| LLT s64 = LLT::scalar(64); | ||
| LLT v2s16 = LLT::fixed_vector(2, s16); | ||
| LLT v4s16 = LLT::fixed_vector(4, s16); | ||
| LLT v2s32 = LLT::fixed_vector(2, s32); | ||
| LLT v4s32 = LLT::fixed_vector(4, s32); | ||
| LLT v2s64 = LLT::fixed_vector(2, s64); | ||
|
|
||
| SmallVector<Register> RegsToUnmergeTo; | ||
| SmallVector<Register> TruncOddDstRegs; | ||
| SmallVector<Register> RegsToMerge; | ||
|
|
||
| unsigned ElemCount = SrcTy.getNumElements(); | ||
|
|
||
| // Pick the widest chunk for the final truncate: 4 elements when ElemCount is a multiple of 4, otherwise 2 | ||
| int StepSize = ElemCount % 4 ? 2 : 4; | ||
|
|
||
| // A source of at most 2 elements is used directly as the only piece; a wider | ||
| // source is first unmerged into ElemCount / 2 registers of type v2s64 | ||
| if (ElemCount <= 2) | ||
| RegsToUnmergeTo.push_back(Src); | ||
| else { | ||
| for (unsigned i = 0; i < ElemCount / 2; ++i) { | ||
HolyMolyCowMan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| RegsToUnmergeTo.push_back(MRI.createGenericVirtualRegister(v2s64)); | ||
| } | ||
|
|
||
| MIRBuilder.buildUnmerge(RegsToUnmergeTo, Src); | ||
| } | ||
|
|
||
| // Emit one round-to-odd G_FPTRUNC_ODD per piece, each producing a v2s32, and record the results in piece order | ||
| for (auto SrcReg : RegsToUnmergeTo) { | ||
| Register Mid = | ||
| MIRBuilder.buildInstr(AArch64::G_FPTRUNC_ODD, {v2s32}, {SrcReg}) | ||
| .getReg(0); | ||
| TruncOddDstRegs.push_back(Mid); | ||
| } | ||
|
|
||
| // When StepSize is 4, concatenate two consecutive v2s32 results into a v4s32 and truncate | ||
| // that to v4s16 (fewer instructions); otherwise truncate each v2s32 to v2s16 on its own. | ||
| unsigned Index = 0; | ||
| for (unsigned LoopIter = 0; LoopIter < ElemCount / StepSize; ++LoopIter) { | ||
| if (StepSize == 4) { | ||
| Register ConcatDst = | ||
| MIRBuilder | ||
| .buildMergeLikeInstr( | ||
| {v4s32}, {TruncOddDstRegs[Index++], TruncOddDstRegs[Index++]}) | ||
| .getReg(0); | ||
|
|
||
| RegsToMerge.push_back( | ||
| MIRBuilder.buildFPTrunc(v4s16, ConcatDst).getReg(0)); | ||
| } else { | ||
| RegsToMerge.push_back( | ||
| MIRBuilder.buildFPTrunc(v2s16, TruncOddDstRegs[Index++]).getReg(0)); | ||
| } | ||
| } | ||
|
|
||
| // A single partial result needs no merge: redirect Dst to it and erase the original instruction | ||
| if (RegsToMerge.size() == 1) { | ||
| MRI.replaceRegWith(Dst, RegsToMerge.pop_back_val()); | ||
| MI.eraseFromParent(); | ||
| return true; | ||
| } | ||
|
|
||
| // Otherwise merge all partial results into one DstTy value, redirect Dst to it and erase the original instruction | ||
| Register Fin = MIRBuilder.buildMergeLikeInstr(DstTy, RegsToMerge).getReg(0); | ||
| MRI.replaceRegWith(Dst, Fin); | ||
| MI.eraseFromParent(); | ||
| return true; | ||
| } | ||
HolyMolyCowMan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.