Skip to content

Commit f3db0cb

Browse files
Reland "[RISCV] Refactor X60 scheduling model helper classes. NFC." (#152336)
This PR fixes the issue that caused undefined behavior (UB) in PR #151472. The issue was a shl call taking a negative shift amount (posDiff). The result was never used, but tablegen would perform the calculation anyway. The fix was to replace the shl call with plain multiplications by constants.

Original PR description:

This patch improves the helper classes in the SpacemiT-X60 vector scheduling model and will be used in follow-up PRs.

There are now two functions to map LMUL to values:
* ConstValueUntilLMULThenDoubleBase: returns BaseValue for LMUL values before startLMUL, Value for startLMUL, then doubles Value for each subsequent LMUL. Useful for cases where fractional LMULs have constant cycles, and integer LMULs double as they increase.
* GetLMULValue: takes an ordered list of LMUL cycles and an LMUL, and returns the corresponding cycle. Useful for cases we can't easily cover with ConstValueUntilLMULThenDoubleBase.

This PR also adds some useful simplified versions of ConstValueUntilLMULThenDoubleBase, e.g.: ConstValueUntilLMULThenDouble (when BaseValue == Value), or ConstOneUntilMF4ThenDouble (when cycles are 1 up to MF4 and start to double after MF4, i.e., from MF2 onward).
1 parent 82f5bd6 commit f3db0cb

File tree

1 file changed

+93
-70
lines changed

1 file changed

+93
-70
lines changed

llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td

Lines changed: 93 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,78 +13,113 @@
1313
//
1414
//===----------------------------------------------------------------------===//
1515

16-
class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
17-
string LLMUL = LargestLMUL<MxList>.r;
18-
bit c = !eq(mx, LLMUL);
19-
}
16+
//===----------------------------------------------------------------------===//
17+
// Helpers
18+
19+
// Maps LMUL string to corresponding value from the Values array
20+
// LMUL values map to array indices as follows:
21+
// MF8 -> Values[0], MF4 -> Values[1], MF2 -> Values[2], M1 -> Values[3],
22+
// M2 -> Values[4], M4 -> Values[5], M8 -> Values[6]
23+
// Shorter lists are allowed, e.g., widening instructions don't work on M8
24+
class GetLMULValue<list<int> Values, string LMUL> {
25+
defvar Index = !cond(
26+
!eq(LMUL, "MF8"): 0,
27+
!eq(LMUL, "MF4"): 1,
28+
!eq(LMUL, "MF2"): 2,
29+
!eq(LMUL, "M1"): 3,
30+
!eq(LMUL, "M2"): 4,
31+
!eq(LMUL, "M4"): 5,
32+
!eq(LMUL, "M8"): 6,
33+
);
2034

21-
class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
22-
string LLMUL = LargestLMUL<MxList>.r;
23-
int SSEW = SmallestSEW<mx, isF>.r;
24-
bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
35+
assert !lt(Index, !size(Values)),
36+
"Missing LMUL value for '" # LMUL # "'. " #
37+
"Expected at least " # !add(Index, 1) # " elements, but got " #
38+
!size(Values) # ".";
39+
40+
int c = Values[Index];
2541
}
2642

27-
defvar SMX60VLEN = 256;
28-
defvar SMX60DLEN = !div(SMX60VLEN, 2);
43+
// Returns BaseValue for LMUL values before startLMUL, Value for startLMUL,
44+
// then doubles Value for each subsequent LMUL
45+
// Example: ConstValueUntilLMULThenDoubleBase<"M1", 2, 4, mx>.c evaluates, for each mx, to:
46+
// MF8->2, MF4->2, MF2->2, M1->4, M2->8, M4->16, M8->32
47+
// This is useful for modeling scheduling parameters that scale with LMUL.
48+
class ConstValueUntilLMULThenDoubleBase<string startLMUL, int BaseValue, int Value, string currentLMUL> {
49+
assert !le(BaseValue, Value), "BaseValue must be less-equal to Value";
50+
defvar startPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], startLMUL>.c;
51+
defvar currentPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], currentLMUL>.c;
2952

30-
class Get1248Latency<string mx> {
53+
// Calculate the difference in positions
54+
defvar posDiff = !sub(currentPos, startPos);
55+
56+
// Calculate Value * (2^posDiff)
3157
int c = !cond(
32-
!eq(mx, "M2") : 2,
33-
!eq(mx, "M4") : 4,
34-
!eq(mx, "M8") : 8,
35-
true: 1
58+
!eq(posDiff, 0) : Value,
59+
!eq(posDiff, 1) : !mul(Value, 2),
60+
!eq(posDiff, 2) : !mul(Value, 4),
61+
!eq(posDiff, 3) : !mul(Value, 8),
62+
!eq(posDiff, 4) : !mul(Value, 16),
63+
!eq(posDiff, 5) : !mul(Value, 32),
64+
!eq(posDiff, 6) : !mul(Value, 64),
65+
true : BaseValue
3666
);
3767
}
3868

39-
// Used for: logical opsz, shifts, sign ext, merge/move, FP sign/recip/convert, mask ops, slides
40-
class Get4816Latency<string mx> {
41-
int c = !cond(
42-
!eq(mx, "M4") : 8,
43-
!eq(mx, "M8") : 16,
44-
true: 4
45-
);
69+
// Same as ConstValueUntilLMULThenDoubleBase, but with BaseValue == Value
70+
class ConstValueUntilLMULThenDouble<string startLMUL, int Value, string currentLMUL> {
71+
int c = ConstValueUntilLMULThenDoubleBase<startLMUL, Value, Value, currentLMUL>.c;
72+
}
73+
74+
// Returns MF8->1, MF4->1, MF2->2, M1->4, M2->8, M4->16, M8->32
75+
class ConstOneUntilMF4ThenDouble<string mx> {
76+
int c = ConstValueUntilLMULThenDouble<"MF4", 1, mx>.c;
77+
}
78+
79+
// Returns MF8->1, MF4->1, MF2->1, M1->2, M2->4, M4->8, M8->16
80+
class ConstOneUntilMF2ThenDouble<string mx> {
81+
int c = ConstValueUntilLMULThenDouble<"MF2", 1, mx>.c;
82+
}
83+
84+
// Returns MF8->1, MF4->1, MF2->1, M1->1, M2->2, M4->4, M8->8
85+
class ConstOneUntilM1ThenDouble<string mx> {
86+
int c = ConstValueUntilLMULThenDouble<"M1", 1, mx>.c;
4687
}
4788

89+
//===----------------------------------------------------------------------===//
90+
// Latency helper classes
91+
4892
// Used for: arithmetic (add/sub/min/max), saturating/averaging, FP add/sub/min/max
49-
class Get458Latency<string mx> {
50-
int c = !cond(
51-
!eq(mx, "M4") : 5,
52-
!eq(mx, "M8") : 8,
53-
true: 4
54-
);
93+
class Get4458Latency<string mx> {
94+
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/4, /*M4=*/5, /*M8=*/8], mx>.c;
5595
}
5696

57-
// Widening scaling pattern (4,4,4,4,5,8,8): plateaus at higher LMULs
58-
// Used for: widening operations
97+
// Used for: widening operations (no M8)
5998
class Get4588Latency<string mx> {
60-
int c = !cond(
61-
!eq(mx, "M2") : 5,
62-
!eq(mx, "M4") : 8,
63-
!eq(mx, "M8") : 8, // M8 not supported for most widening, fallback
64-
true: 4
65-
);
99+
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/5, /*M4=*/8], mx>.c;
66100
}
67101

68102
// Used for: mask-producing comparisons, carry ops with mask, FP comparisons
69103
class Get461018Latency<string mx> {
70-
int c = !cond(
71-
!eq(mx, "M2") : 6,
72-
!eq(mx, "M4") : 10,
73-
!eq(mx, "M8") : 18,
74-
true: 4
75-
);
104+
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/6, /*M4=*/10, /*M8=*/18], mx>.c;
76105
}
77106

78-
// Used for: e64 multiply pattern, complex ops
79-
class Get781632Latency<string mx> {
80-
int c = !cond(
81-
!eq(mx, "M2") : 8,
82-
!eq(mx, "M4") : 16,
83-
!eq(mx, "M8") : 32,
84-
true: 7
85-
);
107+
//===----------------------------------------------------------------------===//
108+
109+
class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
110+
string LLMUL = LargestLMUL<MxList>.r;
111+
bit c = !eq(mx, LLMUL);
86112
}
87113

114+
class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
115+
string LLMUL = LargestLMUL<MxList>.r;
116+
int SSEW = SmallestSEW<mx, isF>.r;
117+
bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
118+
}
119+
120+
defvar SMX60VLEN = 256;
121+
defvar SMX60DLEN = !div(SMX60VLEN, 2);
122+
88123
def SpacemitX60Model : SchedMachineModel {
89124
let IssueWidth = 2; // dual-issue
90125
let MicroOpBufferSize = 0; // in-order
@@ -383,12 +418,13 @@ foreach LMul = [1, 2, 4, 8] in {
383418
foreach mx = SchedMxList in {
384419
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxList>.c;
385420

386-
let Latency = Get458Latency<mx>.c, ReleaseAtCycles = [4] in {
421+
let Latency = Get4458Latency<mx>.c, ReleaseAtCycles = [4] in {
387422
defm "" : LMULWriteResMX<"WriteVIMinMaxV", [SMX60_VIEU], mx, IsWorstCase>;
388423
defm "" : LMULWriteResMX<"WriteVIMinMaxX", [SMX60_VIEU], mx, IsWorstCase>;
389424
}
390425

391-
let Latency = Get4816Latency<mx>.c, ReleaseAtCycles = [4] in {
426+
defvar VIALULat = ConstValueUntilLMULThenDouble<"M2", 4, mx>.c;
427+
let Latency = VIALULat, ReleaseAtCycles = [4] in {
392428
// Pattern of vadd, vsub, vrsub: 4/4/5/8
393429
// Pattern of vand, vor, vxor: 4/4/8/16
394430
// They are grouped together, so we used the worst case 4/4/8/16
@@ -425,7 +461,7 @@ foreach mx = SchedMxList in {
425461
// Pattern of vmacc, vmadd, vmul, vmulh, etc.: e8/e16 = 4/4/5/8, e32 = 5,5,5,8,
426462
// e64 = 7,8,16,32. We use the worst-case until we can split the SEW.
427463
// TODO: change WriteVIMulV, etc to be defined with LMULSEWSchedWrites
428-
let Latency = Get781632Latency<mx>.c, ReleaseAtCycles = [7] in {
464+
let Latency = ConstValueUntilLMULThenDoubleBase<"M2", 7, 8, mx>.c, ReleaseAtCycles = [7] in {
429465
defm "" : LMULWriteResMX<"WriteVIMulV", [SMX60_VIEU], mx, IsWorstCase>;
430466
defm "" : LMULWriteResMX<"WriteVIMulX", [SMX60_VIEU], mx, IsWorstCase>;
431467
defm "" : LMULWriteResMX<"WriteVIMulAddV", [SMX60_VIEU], mx, IsWorstCase>;
@@ -461,15 +497,8 @@ foreach mx = SchedMxList in {
461497
foreach sew = SchedSEWSet<mx>.val in {
462498
defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxList>.c;
463499

464-
// Slightly reduced for fractional LMULs
465-
defvar Multiplier = !cond(
466-
!eq(mx, "MF8") : 12,
467-
!eq(mx, "MF4") : 12,
468-
!eq(mx, "MF2") : 12,
469-
true: 24
470-
);
471-
472-
let Latency = !mul(Get1248Latency<mx>.c, Multiplier), ReleaseAtCycles = [12] in {
500+
defvar VIDivLat = ConstValueUntilLMULThenDouble<"MF2", 12, mx>.c;
501+
let Latency = VIDivLat, ReleaseAtCycles = [12] in {
473502
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivV", [SMX60_VIEU], mx, sew, IsWorstCase>;
474503
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivX", [SMX60_VIEU], mx, sew, IsWorstCase>;
475504
}
@@ -480,14 +509,8 @@ foreach mx = SchedMxList in {
480509
foreach mx = SchedMxListW in {
481510
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxListW>.c;
482511

483-
// Slightly increased for integer LMULs
484-
defvar Multiplier = !cond(
485-
!eq(mx, "M2") : 2,
486-
!eq(mx, "M4") : 2,
487-
true: 1
488-
);
489-
490-
let Latency = !mul(Get4816Latency<mx>.c, Multiplier), ReleaseAtCycles = [4] in {
512+
defvar VNarrowingLat = ConstValueUntilLMULThenDouble<"M1", 4, mx>.c;
513+
let Latency = VNarrowingLat, ReleaseAtCycles = [4] in {
491514
defm "" : LMULWriteResMX<"WriteVNShiftV", [SMX60_VIEU], mx, IsWorstCase>;
492515
defm "" : LMULWriteResMX<"WriteVNShiftX", [SMX60_VIEU], mx, IsWorstCase>;
493516
defm "" : LMULWriteResMX<"WriteVNShiftI", [SMX60_VIEU], mx, IsWorstCase>;

0 commit comments

Comments
 (0)