Skip to content

Commit d7d04d2

Browse files
committed
[GlobalISel] Add boolean predicated legalization action methods.
Under AArch64 it is common, and will become more common, to have operation legalization rules dependent on a feature of the architecture. For example HasFP16 or the newer CSSC integer min/max instructions, among many others. With the current legalization rules this either means adding a custom predicate based on the feature, as in `legalIf([=](const LegalityQuery &Query) { return HasFP16 && ...; })`, or splitting the legalization rules into pieces that optionally place rules into them based on the features available. This patch proposes an alternative where the existing routines like legalFor(..) are provided a boolean predicate, which if false skips adding the rule. It makes the rules cleaner and will hopefully allow them to scale better as we add more features.
1 parent d2408c4 commit d7d04d2

File tree

3 files changed

+119
-165
lines changed

3 files changed

+119
-165
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,22 @@ class LegalizeRuleSet {
599599
LegalizeRuleSet &legalFor(std::initializer_list<LLT> Types) {
600600
return actionFor(LegalizeAction::Legal, Types);
601601
}
602+
LegalizeRuleSet &legalFor(bool Pred, std::initializer_list<LLT> Types) {
603+
if (!Pred)
604+
return *this;
605+
return actionFor(LegalizeAction::Legal, Types);
606+
}
602607
/// The instruction is legal when type indexes 0 and 1 is any type pair in the
603608
/// given list.
604609
LegalizeRuleSet &legalFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
605610
return actionFor(LegalizeAction::Legal, Types);
606611
}
612+
LegalizeRuleSet &legalFor(bool Pred,
613+
std::initializer_list<std::pair<LLT, LLT>> Types) {
614+
if (!Pred)
615+
return *this;
616+
return actionFor(LegalizeAction::Legal, Types);
617+
}
607618
/// The instruction is legal when type index 0 is any type in the given list
608619
/// and imm index 0 is anything.
609620
LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
@@ -846,12 +857,23 @@ class LegalizeRuleSet {
846857
LegalizeRuleSet &customFor(std::initializer_list<LLT> Types) {
847858
return actionFor(LegalizeAction::Custom, Types);
848859
}
860+
LegalizeRuleSet &customFor(bool Pred, std::initializer_list<LLT> Types) {
861+
if (!Pred)
862+
return *this;
863+
return actionFor(LegalizeAction::Custom, Types);
864+
}
849865

850-
/// The instruction is custom when type indexes 0 and 1 is any type pair in the
851-
/// given list.
866+
/// The instruction is custom when type indexes 0 and 1 is any type pair in
867+
/// the given list.
852868
LegalizeRuleSet &customFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
853869
return actionFor(LegalizeAction::Custom, Types);
854870
}
871+
LegalizeRuleSet &customFor(bool Pred,
872+
std::initializer_list<std::pair<LLT, LLT>> Types) {
873+
if (!Pred)
874+
return *this;
875+
return actionFor(LegalizeAction::Custom, Types);
876+
}
855877

856878
LegalizeRuleSet &customForCartesianProduct(std::initializer_list<LLT> Types) {
857879
return actionForCartesianProduct(LegalizeAction::Custom, Types);
@@ -990,6 +1012,11 @@ class LegalizeRuleSet {
9901012
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
9911013
changeTo(typeIdx(TypeIdx), Ty));
9921014
}
1015+
LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
1016+
if (!Pred)
1017+
return *this;
1018+
return minScalar(TypeIdx, Ty);
1019+
}
9931020

9941021
/// Ensure the scalar is at least as wide as Ty if condition is met.
9951022
LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 54 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -215,19 +215,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
215215
.legalFor({s64, v8s16, v16s8, v4s32})
216216
.lower();
217217

218-
auto &MinMaxActions = getActionDefinitionsBuilder(
219-
{G_SMIN, G_SMAX, G_UMIN, G_UMAX});
220-
if (HasCSSC)
221-
MinMaxActions
222-
.legalFor({s32, s64, v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
223-
// Making clamping conditional on CSSC extension as without legal types we
224-
// lower to CMP which can fold one of the two sxtb's we'd otherwise need
225-
// if we detect a type smaller than 32-bit.
226-
.minScalar(0, s32);
227-
else
228-
MinMaxActions
229-
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32});
230-
MinMaxActions
218+
getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
219+
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
220+
.legalFor(HasCSSC, {s32, s64})
221+
.minScalar(HasCSSC, 0, s32)
231222
.clampNumElements(0, v8s8, v16s8)
232223
.clampNumElements(0, v4s16, v8s16)
233224
.clampNumElements(0, v2s32, v4s32)
@@ -247,11 +238,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
247238
{G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FSQRT, G_FMAXNUM, G_FMINNUM,
248239
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR, G_FRINT, G_FNEARBYINT,
249240
G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
250-
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
251-
.legalIf([=](const LegalityQuery &Query) {
252-
const auto &Ty = Query.Types[0];
253-
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
254-
})
241+
.legalFor({s32, s64, v2s32, v4s32, v2s64})
242+
.legalFor(HasFP16, {s16, v4s16, v8s16})
255243
.libcallFor({s128})
256244
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
257245
.minScalarOrElt(0, MinFPScalar)
@@ -261,11 +249,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
261249
.moreElementsToNextPow2(0);
262250

263251
getActionDefinitionsBuilder({G_FABS, G_FNEG})
264-
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
265-
.legalIf([=](const LegalityQuery &Query) {
266-
const auto &Ty = Query.Types[0];
267-
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
268-
})
252+
.legalFor({s32, s64, v2s32, v4s32, v2s64})
253+
.legalFor(HasFP16, {s16, v4s16, v8s16})
269254
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
270255
.lowerIf(scalarOrEltWiderThan(0, 64))
271256
.clampNumElements(0, v4s16, v8s16)
@@ -350,31 +335,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
350335
return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
351336
};
352337

353-
auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
354-
auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
355-
356-
if (ST.hasSVE()) {
357-
LoadActions.legalForTypesWithMemDesc({
358-
// 128 bit base sizes
359-
{nxv16s8, p0, nxv16s8, 8},
360-
{nxv8s16, p0, nxv8s16, 8},
361-
{nxv4s32, p0, nxv4s32, 8},
362-
{nxv2s64, p0, nxv2s64, 8},
363-
});
364-
365-
// TODO: Add nxv2p0. Consider bitcastIf.
366-
// See #92130
367-
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
368-
StoreActions.legalForTypesWithMemDesc({
369-
// 128 bit base sizes
370-
{nxv16s8, p0, nxv16s8, 8},
371-
{nxv8s16, p0, nxv8s16, 8},
372-
{nxv4s32, p0, nxv4s32, 8},
373-
{nxv2s64, p0, nxv2s64, 8},
374-
});
375-
}
376-
377-
LoadActions
338+
getActionDefinitionsBuilder(G_LOAD)
378339
.customIf([=](const LegalityQuery &Query) {
379340
return HasRCPC3 && Query.Types[0] == s128 &&
380341
Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -399,6 +360,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
399360
// These extends are also legal
400361
.legalForTypesWithMemDesc(
401362
{{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
363+
.legalForTypesWithMemDesc({
364+
// SVE vscale x 128 bit base sizes
365+
{nxv16s8, p0, nxv16s8, 8},
366+
{nxv8s16, p0, nxv8s16, 8},
367+
{nxv4s32, p0, nxv4s32, 8},
368+
{nxv2s64, p0, nxv2s64, 8},
369+
})
402370
.widenScalarToNextPow2(0, /* MinSize = */ 8)
403371
.clampMaxNumElements(0, s8, 16)
404372
.clampMaxNumElements(0, s16, 8)
@@ -424,7 +392,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
424392
.customIf(IsPtrVecPred)
425393
.scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
426394

427-
StoreActions
395+
getActionDefinitionsBuilder(G_STORE)
428396
.customIf([=](const LegalityQuery &Query) {
429397
return HasRCPC3 && Query.Types[0] == s128 &&
430398
Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
@@ -444,6 +412,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
444412
{p0, p0, s64, 8}, {s128, p0, s128, 8}, {v16s8, p0, s128, 8},
445413
{v8s8, p0, s64, 8}, {v4s16, p0, s64, 8}, {v8s16, p0, s128, 8},
446414
{v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
415+
.legalForTypesWithMemDesc({
416+
// SVE vscale x 128 bit base sizes
417+
// TODO: Add nxv2p0. Consider bitcastIf.
418+
// See #92130
419+
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
420+
{nxv16s8, p0, nxv16s8, 8},
421+
{nxv8s16, p0, nxv8s16, 8},
422+
{nxv4s32, p0, nxv4s32, 8},
423+
{nxv2s64, p0, nxv2s64, 8},
424+
})
447425
.clampScalar(0, s8, s64)
448426
.lowerIf([=](const LegalityQuery &Query) {
449427
return Query.Types[0].isScalar() &&
@@ -530,12 +508,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
530508
.widenScalarToNextPow2(0)
531509
.clampScalar(0, s8, s64);
532510
getActionDefinitionsBuilder(G_FCONSTANT)
533-
.legalIf([=](const LegalityQuery &Query) {
534-
const auto &Ty = Query.Types[0];
535-
if (HasFP16 && Ty == s16)
536-
return true;
537-
return Ty == s32 || Ty == s64 || Ty == s128;
538-
})
511+
.legalFor({s32, s64, s128})
512+
.legalFor(HasFP16, {s16})
539513
.clampScalar(0, MinFPScalar, s128);
540514

541515
// FIXME: fix moreElementsToNextPow2
@@ -567,16 +541,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
567541
.customIf(isVector(0));
568542

569543
getActionDefinitionsBuilder(G_FCMP)
570-
.legalFor({{s32, MinFPScalar},
571-
{s32, s32},
544+
.legalFor({{s32, s32},
572545
{s32, s64},
573546
{v4s32, v4s32},
574547
{v2s32, v2s32},
575548
{v2s64, v2s64}})
576-
.legalIf([=](const LegalityQuery &Query) {
577-
const auto &Ty = Query.Types[1];
578-
return (Ty == v8s16 || Ty == v4s16) && Ty == Query.Types[0] && HasFP16;
579-
})
549+
.legalFor(HasFP16, {{s32, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
580550
.widenScalarOrEltToNextPow2(1)
581551
.clampScalar(0, s32, s32)
582552
.minScalarOrElt(1, MinFPScalar)
@@ -691,13 +661,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
691661
{v2s64, v2s64},
692662
{v4s32, v4s32},
693663
{v2s32, v2s32}})
694-
.legalIf([=](const LegalityQuery &Query) {
695-
return HasFP16 &&
696-
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
697-
Query.Types[1] == v8s16) &&
698-
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
699-
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
700-
})
664+
.legalFor(HasFP16,
665+
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
701666
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
702667
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
703668
// The range of a fp16 value fits into an i17, so we can lower the width
@@ -739,13 +704,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
739704
{v2s64, v2s64},
740705
{v4s32, v4s32},
741706
{v2s32, v2s32}})
742-
.legalIf([=](const LegalityQuery &Query) {
743-
return HasFP16 &&
744-
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
745-
Query.Types[1] == v8s16) &&
746-
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
747-
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
748-
})
707+
.legalFor(HasFP16,
708+
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
749709
// Handle types larger than i64 by scalarizing/lowering.
750710
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
751711
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
@@ -788,13 +748,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
788748
{v2s64, v2s64},
789749
{v4s32, v4s32},
790750
{v2s32, v2s32}})
791-
.legalIf([=](const LegalityQuery &Query) {
792-
return HasFP16 &&
793-
(Query.Types[0] == s16 || Query.Types[0] == v4s16 ||
794-
Query.Types[0] == v8s16) &&
795-
(Query.Types[1] == s32 || Query.Types[1] == s64 ||
796-
Query.Types[1] == v4s16 || Query.Types[1] == v8s16);
797-
})
751+
.legalFor(HasFP16,
752+
{{s16, s32}, {s16, s64}, {v4s16, v4s16}, {v8s16, v8s16}})
798753
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
799754
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
800755
.moreElementsToNextPow2(1)
@@ -1048,12 +1003,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
10481003
.widenScalarToNextPow2(1, /*Min=*/32)
10491004
.clampScalar(1, s32, s64)
10501005
.scalarSameSizeAs(0, 1)
1051-
.legalIf([=](const LegalityQuery &Query) {
1052-
return (HasCSSC && typeInSet(0, {s32, s64})(Query));
1053-
})
1054-
.customIf([=](const LegalityQuery &Query) {
1055-
return (!HasCSSC && typeInSet(0, {s32, s64})(Query));
1056-
});
1006+
.legalFor(HasCSSC, {s32, s64})
1007+
.customFor(!HasCSSC, {s32, s64});
10571008

10581009
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
10591010
.legalIf([=](const LegalityQuery &Query) {
@@ -1141,11 +1092,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
11411092
}
11421093

11431094
// FIXME: Legal vector types are only legal with NEON.
1144-
auto &ABSActions = getActionDefinitionsBuilder(G_ABS);
1145-
if (HasCSSC)
1146-
ABSActions
1147-
.legalFor({s32, s64});
1148-
ABSActions.legalFor(PackedVectorAllTypeList)
1095+
getActionDefinitionsBuilder(G_ABS)
1096+
.legalFor(HasCSSC, {s32, s64})
1097+
.legalFor(PackedVectorAllTypeList)
11491098
.customIf([=](const LegalityQuery &Q) {
11501099
// TODO: Fix suboptimal codegen for 128+ bit types.
11511100
LLT SrcTy = Q.Types[0];
@@ -1169,10 +1118,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
11691118
// later.
11701119
getActionDefinitionsBuilder(G_VECREDUCE_FADD)
11711120
.legalFor({{s32, v2s32}, {s32, v4s32}, {s64, v2s64}})
1172-
.legalIf([=](const LegalityQuery &Query) {
1173-
const auto &Ty = Query.Types[1];
1174-
return (Ty == v4s16 || Ty == v8s16) && HasFP16;
1175-
})
1121+
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
11761122
.minScalarOrElt(0, MinFPScalar)
11771123
.clampMaxNumElements(1, s64, 2)
11781124
.clampMaxNumElements(1, s32, 4)
@@ -1213,10 +1159,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
12131159
getActionDefinitionsBuilder({G_VECREDUCE_FMIN, G_VECREDUCE_FMAX,
12141160
G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM})
12151161
.legalFor({{s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
1216-
.legalIf([=](const LegalityQuery &Query) {
1217-
const auto &Ty = Query.Types[1];
1218-
return Query.Types[0] == s16 && (Ty == v8s16 || Ty == v4s16) && HasFP16;
1219-
})
1162+
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
12201163
.minScalarOrElt(0, MinFPScalar)
12211164
.clampMaxNumElements(1, s64, 2)
12221165
.clampMaxNumElements(1, s32, 4)
@@ -1293,32 +1236,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
12931236
.customFor({{s32, s32}, {s64, s64}});
12941237

12951238
auto always = [=](const LegalityQuery &Q) { return true; };
1296-
auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
1297-
if (HasCSSC)
1298-
CTPOPActions
1299-
.legalFor({{s32, s32},
1300-
{s64, s64},
1301-
{v8s8, v8s8},
1302-
{v16s8, v16s8}})
1303-
.customFor({{s128, s128},
1304-
{v2s64, v2s64},
1305-
{v2s32, v2s32},
1306-
{v4s32, v4s32},
1307-
{v4s16, v4s16},
1308-
{v8s16, v8s16}});
1309-
else
1310-
CTPOPActions
1311-
.legalFor({{v8s8, v8s8},
1312-
{v16s8, v16s8}})
1313-
.customFor({{s32, s32},
1314-
{s64, s64},
1315-
{s128, s128},
1316-
{v2s64, v2s64},
1317-
{v2s32, v2s32},
1318-
{v4s32, v4s32},
1319-
{v4s16, v4s16},
1320-
{v8s16, v8s16}});
1321-
CTPOPActions
1239+
getActionDefinitionsBuilder(G_CTPOP)
1240+
.legalFor(HasCSSC, {{s32, s32}, {s64, s64}})
1241+
.legalFor({{v8s8, v8s8}, {v16s8, v16s8}})
1242+
.customFor(!HasCSSC, {{s32, s32}, {s64, s64}})
1243+
.customFor({{s128, s128},
1244+
{v2s64, v2s64},
1245+
{v2s32, v2s32},
1246+
{v4s32, v4s32},
1247+
{v4s16, v4s16},
1248+
{v8s16, v8s16}})
13221249
.clampScalar(0, s32, s128)
13231250
.widenScalarToNextPow2(0)
13241251
.minScalarEltSameAsIf(always, 1, 0)

0 commit comments

Comments
 (0)