Skip to content

[Reland][RISCV] Refactor X60 scheduling model helper classes. NFC. #152336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 93 additions & 70 deletions llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,78 +13,113 @@
//
//===----------------------------------------------------------------------===//

class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
string LLMUL = LargestLMUL<MxList>.r;
bit c = !eq(mx, LLMUL);
}
//===----------------------------------------------------------------------===//
// Helpers

// Maps LMUL string to corresponding value from the Values array
// LMUL values map to array indices as follows:
// MF8 -> Values[0], MF4 -> Values[1], MF2 -> Values[2], M1 -> Values[3],
// M2 -> Values[4], M4 -> Values[5], M8 -> Values[6]
// Shorter lists are allowed, e.g., widening instructions don't work on M8
class GetLMULValue<list<int> Values, string LMUL> {
defvar Index = !cond(
!eq(LMUL, "MF8"): 0,
!eq(LMUL, "MF4"): 1,
!eq(LMUL, "MF2"): 2,
!eq(LMUL, "M1"): 3,
!eq(LMUL, "M2"): 4,
!eq(LMUL, "M4"): 5,
!eq(LMUL, "M8"): 6,
);

class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
string LLMUL = LargestLMUL<MxList>.r;
int SSEW = SmallestSEW<mx, isF>.r;
bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
assert !lt(Index, !size(Values)),
"Missing LMUL value for '" # LMUL # "'. " #
"Expected at least " # !add(Index, 1) # " elements, but got " #
!size(Values) # ".";

int c = Values[Index];
}

defvar SMX60VLEN = 256;
defvar SMX60DLEN = !div(SMX60VLEN, 2);
// Returns BaseValue for LMUL values before startLMUL, Value for startLMUL,
// then doubles Value for each subsequent LMUL
// Example: ConstValueUntilLMULThenDoubleBase<"M1", 2, 4, "M8"> returns:
// MF8->2, MF4->2, MF2->2, M1->4, M2->8, M4->16, M8->32
// This is useful for modeling scheduling parameters that scale with LMUL.
class ConstValueUntilLMULThenDoubleBase<string startLMUL, int BaseValue, int Value, string currentLMUL> {
assert !le(BaseValue, Value), "BaseValue must be less-equal to Value";
defvar startPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], startLMUL>.c;
defvar currentPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], currentLMUL>.c;

class Get1248Latency<string mx> {
// Calculate the difference in positions
defvar posDiff = !sub(currentPos, startPos);

// Calculate Value * (2^posDiff)
int c = !cond(
!eq(mx, "M2") : 2,
!eq(mx, "M4") : 4,
!eq(mx, "M8") : 8,
true: 1
!eq(posDiff, 0) : Value,
!eq(posDiff, 1) : !mul(Value, 2),
!eq(posDiff, 2) : !mul(Value, 4),
!eq(posDiff, 3) : !mul(Value, 8),
!eq(posDiff, 4) : !mul(Value, 16),
!eq(posDiff, 5) : !mul(Value, 32),
!eq(posDiff, 6) : !mul(Value, 64),
true : BaseValue
);
}

// Used for: logical opsz, shifts, sign ext, merge/move, FP sign/recip/convert, mask ops, slides
class Get4816Latency<string mx> {
int c = !cond(
!eq(mx, "M4") : 8,
!eq(mx, "M8") : 16,
true: 4
);
// Same as the previous function but BaseValue == Value
class ConstValueUntilLMULThenDouble<string startLMUL, int Value, string currentLMUL> {
int c = ConstValueUntilLMULThenDoubleBase<startLMUL, Value, Value, currentLMUL>.c;
}

// Returns MF8->1, MF4->1, MF2->2, M1->4, M2->8, M4->16, M8->32
class ConstOneUntilMF4ThenDouble<string mx> {
int c = ConstValueUntilLMULThenDouble<"MF4", 1, mx>.c;
}

// Returns MF8->1, MF4->1, MF2->1, M1->2, M2->4, M4->8, M8->16
class ConstOneUntilMF2ThenDouble<string mx> {
int c = ConstValueUntilLMULThenDouble<"MF2", 1, mx>.c;
}

// Returns MF8->1, MF4->1, MF2->1, M1->1, M2->2, M4->4, M8->8
class ConstOneUntilM1ThenDouble<string mx> {
int c = ConstValueUntilLMULThenDouble<"M1", 1, mx>.c;
}

//===----------------------------------------------------------------------===//
// Latency helper classes

// Used for: arithmetic (add/sub/min/max), saturating/averaging, FP add/sub/min/max
class Get458Latency<string mx> {
int c = !cond(
!eq(mx, "M4") : 5,
!eq(mx, "M8") : 8,
true: 4
);
class Get4458Latency<string mx> {
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/4, /*M4=*/5, /*M8=*/8], mx>.c;
}

// Widening scaling pattern (4,4,4,4,5,8,8): plateaus at higher LMULs
// Used for: widening operations
// Used for: widening operations (no M8)
class Get4588Latency<string mx> {
int c = !cond(
!eq(mx, "M2") : 5,
!eq(mx, "M4") : 8,
!eq(mx, "M8") : 8, // M8 not supported for most widening, fallback
true: 4
);
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/5, /*M4=*/8], mx>.c;
}

// Used for: mask-producing comparisons, carry ops with mask, FP comparisons
class Get461018Latency<string mx> {
int c = !cond(
!eq(mx, "M2") : 6,
!eq(mx, "M4") : 10,
!eq(mx, "M8") : 18,
true: 4
);
int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/6, /*M4=*/10, /*M8=*/18], mx>.c;
}

// Used for: e64 multiply pattern, complex ops
class Get781632Latency<string mx> {
int c = !cond(
!eq(mx, "M2") : 8,
!eq(mx, "M4") : 16,
!eq(mx, "M8") : 32,
true: 7
);
//===----------------------------------------------------------------------===//

class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
string LLMUL = LargestLMUL<MxList>.r;
bit c = !eq(mx, LLMUL);
}

class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
string LLMUL = LargestLMUL<MxList>.r;
int SSEW = SmallestSEW<mx, isF>.r;
bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
}

defvar SMX60VLEN = 256;
defvar SMX60DLEN = !div(SMX60VLEN, 2);

def SpacemitX60Model : SchedMachineModel {
let IssueWidth = 2; // dual-issue
let MicroOpBufferSize = 0; // in-order
Expand Down Expand Up @@ -383,12 +418,13 @@ foreach LMul = [1, 2, 4, 8] in {
foreach mx = SchedMxList in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxList>.c;

let Latency = Get458Latency<mx>.c, ReleaseAtCycles = [4] in {
let Latency = Get4458Latency<mx>.c, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVIMinMaxV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMinMaxX", [SMX60_VIEU], mx, IsWorstCase>;
}

let Latency = Get4816Latency<mx>.c, ReleaseAtCycles = [4] in {
defvar VIALULat = ConstValueUntilLMULThenDouble<"M2", 4, mx>.c;
let Latency = VIALULat, ReleaseAtCycles = [4] in {
// Pattern of vadd, vsub, vrsub: 4/4/5/8
// Pattern of vand, vor, vxor: 4/4/8/16
// They are grouped together, so we used the worst case 4/4/8/16
Expand Down Expand Up @@ -425,7 +461,7 @@ foreach mx = SchedMxList in {
// Pattern of vmacc, vmadd, vmul, vmulh, etc.: e8/e16 = 4/4/5/8, e32 = 5,5,5,8,
// e64 = 7,8,16,32. We use the worst-case until we can split the SEW.
// TODO: change WriteVIMulV, etc to be defined with LMULSEWSchedWrites
let Latency = Get781632Latency<mx>.c, ReleaseAtCycles = [7] in {
let Latency = ConstValueUntilLMULThenDoubleBase<"M2", 7, 8, mx>.c, ReleaseAtCycles = [7] in {
defm "" : LMULWriteResMX<"WriteVIMulV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulAddV", [SMX60_VIEU], mx, IsWorstCase>;
Expand Down Expand Up @@ -461,15 +497,8 @@ foreach mx = SchedMxList in {
foreach sew = SchedSEWSet<mx>.val in {
defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxList>.c;

// Slightly reduced for fractional LMULs
defvar Multiplier = !cond(
!eq(mx, "MF8") : 12,
!eq(mx, "MF4") : 12,
!eq(mx, "MF2") : 12,
true: 24
);

let Latency = !mul(Get1248Latency<mx>.c, Multiplier), ReleaseAtCycles = [12] in {
defvar VIDivLat = ConstValueUntilLMULThenDouble<"MF2", 12, mx>.c;
let Latency = VIDivLat, ReleaseAtCycles = [12] in {
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivV", [SMX60_VIEU], mx, sew, IsWorstCase>;
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivX", [SMX60_VIEU], mx, sew, IsWorstCase>;
}
Expand All @@ -480,14 +509,8 @@ foreach mx = SchedMxList in {
foreach mx = SchedMxListW in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxListW>.c;

// Slightly increased for integer LMULs
defvar Multiplier = !cond(
!eq(mx, "M2") : 2,
!eq(mx, "M4") : 2,
true: 1
);

let Latency = !mul(Get4816Latency<mx>.c, Multiplier), ReleaseAtCycles = [4] in {
defvar VNarrowingLat = ConstValueUntilLMULThenDouble<"M1", 4, mx>.c;
let Latency = VNarrowingLat, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVNShiftV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftI", [SMX60_VIEU], mx, IsWorstCase>;
Expand Down