Skip to content

Commit dbe05c6

Browse files
Hardcode84 and claude committed
Handle PackOp inputs in regalloc liveness and linear scan
PackOp is a register allocation directive: its N inputs must form a contiguous register block matching the pack result. Previously, pack inputs got independent allocations to arbitrary registers while the result got a correct contiguous allocation, leaving downstream consumers reading uninitialized physical registers.

Fix by treating pack inputs as sub-registers of the pack result:
- Liveness: extend the pack result's live range backwards to cover input defs, then remove inputs from allocation worklists.
- LinearScanPass: post-pass assigns input[i].physReg = result + i, mirroring the existing ExtractOp post-pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent 6560184 commit dbe05c6

File tree

4 files changed

+142
-20
lines changed

4 files changed

+142
-20
lines changed

waveasm/lib/Transforms/LinearScanPass.cpp

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ namespace waveasm {
3333
#include "waveasm/Transforms/Passes.h.inc"
3434
} // namespace waveasm
3535

36+
/// Get a value's physical register index from the mapping, falling back to
37+
/// the type's index for already-physical (precolored) values.
38+
/// Returns -1 if the value has no physical register assignment.
39+
static int64_t getEffectivePhysReg(Value value,
40+
const PhysicalMapping &mapping) {
41+
int64_t physReg = mapping.getPhysReg(value);
42+
if (physReg >= 0)
43+
return physReg;
44+
Type ty = value.getType();
45+
if (auto pvreg = dyn_cast<PVRegType>(ty))
46+
return pvreg.getIndex();
47+
if (auto pareg = dyn_cast<PARegType>(ty))
48+
return pareg.getIndex();
49+
if (auto psreg = dyn_cast<PSRegType>(ty))
50+
return psreg.getIndex();
51+
return -1;
52+
}
53+
3654
/// Convert a virtual register type to a physical register type.
3755
/// Also handles re-indexing an already-physical type to a new physReg.
3856
/// Returns the original type unchanged if it's not a register type
@@ -236,30 +254,40 @@ struct LinearScanPass
236254

237255
auto [mapping, stats] = *result;
238256

239-
// Handle waveasm.extract ops: result = source[offset]
240-
// Set the extract result's physical register = source's physical register +
241-
// offset
257+
// Handle waveasm.extract ops: result = source[offset].
258+
// Set the extract result's physical register = source's physReg + offset.
242259
program.walk([&](ExtractOp extractOp) {
243-
Value source = extractOp.getVector();
244-
Value extractResult = extractOp.getResult();
245-
int64_t index = extractOp.getIndex();
246-
247-
// Get source's physical register (may be precolored or allocated)
248-
int64_t sourcePhysReg = -1;
249-
Type srcType = source.getType();
250-
if (auto pvreg = dyn_cast<PVRegType>(srcType)) {
251-
sourcePhysReg = pvreg.getIndex();
252-
} else if (auto pareg = dyn_cast<PARegType>(srcType)) {
253-
sourcePhysReg = pareg.getIndex();
254-
} else {
255-
sourcePhysReg = mapping.getPhysReg(source);
256-
}
260+
int64_t sourcePhysReg =
261+
getEffectivePhysReg(extractOp.getVector(), mapping);
262+
if (sourcePhysReg >= 0)
263+
mapping.setPhysReg(extractOp.getResult(),
264+
sourcePhysReg + extractOp.getIndex());
265+
});
257266

258-
if (sourcePhysReg >= 0) {
259-
// Set the extract result to source + offset
260-
mapping.setPhysReg(extractResult, sourcePhysReg + index);
267+
// Handle waveasm.pack ops: input[i] gets result's physReg + i.
268+
// Pack inputs were excluded from the allocation worklists during liveness
269+
// analysis, so they have no mapping yet. Assign them here from the pack
270+
// result's contiguous allocation.
271+
WalkResult packResult = program.walk([&](PackOp packOp) {
272+
int64_t resultPhysReg = getEffectivePhysReg(packOp.getResult(), mapping);
273+
if (resultPhysReg < 0) {
274+
packOp.emitError(
275+
"pack result has no physical register; cannot assign inputs");
276+
return WalkResult::interrupt();
261277
}
278+
llvm::DenseSet<Value> seen;
279+
for (auto [i, input] : llvm::enumerate(packOp.getElements())) {
280+
if (!seen.insert(input).second) {
281+
packOp.emitError("duplicate pack input at index ")
282+
<< i << "; each input must be a distinct value";
283+
return WalkResult::interrupt();
284+
}
285+
mapping.setPhysReg(input, resultPhysReg + static_cast<int64_t>(i));
286+
}
287+
return WalkResult::advance();
262288
});
289+
if (packResult.wasInterrupted())
290+
return failure();
263291

264292
// Transform the IR: replace virtual register types with physical types
265293
OpBuilder builder(program.getContext());

waveasm/lib/Transforms/Liveness.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,35 @@ LivenessInfo computeLiveness(ProgramOp program) {
478478
// already handle all necessary live range extensions by directly inspecting
479479
// the region structure.
480480

481+
// Pass 3a: Pack group pass -- treat pack inputs as sub-registers of the
482+
// pack result.
483+
//
484+
// waveasm.pack emits no assembly; it is a register allocation directive
485+
// declaring that N inputs form a contiguous register block. The pack result
486+
// already gets a correct contiguous allocation via allocRange, but pack
487+
// inputs would otherwise get independent allocations to arbitrary registers.
488+
//
489+
// Fix: extend the pack result's live range backwards to cover input defs,
490+
// then remove pack inputs from the allocation worklists entirely. A
491+
// post-pass in LinearScanPass assigns input[i].physReg = result.physReg + i.
492+
program.walk([&](PackOp packOp) {
493+
Value packResult = packOp.getResult();
494+
auto resultIt = info.ranges.find(packResult);
495+
assert(resultIt != info.ranges.end() &&
496+
"pack result must have a live range");
497+
498+
for (Value input : packOp.getElements()) {
499+
// Extend the pack result's range start to cover this input's def.
500+
auto inputIt = info.ranges.find(input);
501+
if (inputIt != info.ranges.end()) {
502+
resultIt->second.start =
503+
std::min(resultIt->second.start, inputIt->second.start);
504+
// Remove from ranges so it won't enter the allocator.
505+
info.ranges.erase(inputIt);
506+
}
507+
}
508+
});
509+
481510
// Pass 3b: Build tied equivalence classes for pressure de-duplication.
482511
//
483512
// LoopOp results, condition iter_args, and block args are all tied to the
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: not waveasm-translate --waveasm-linear-scan %s 2>&1 | FileCheck %s
2+
//
3+
// Test: Duplicate pack inputs are rejected during register allocation.
4+
5+
waveasm.program @pack_duplicate_input target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
6+
%imm0 = waveasm.constant 0 : !waveasm.imm<0>
7+
%v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
8+
9+
// CHECK: error: duplicate pack input at index 1; each input must be a distinct value
10+
%packed = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>
11+
12+
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
13+
%voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
14+
waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
15+
waveasm.s_endpgm
16+
}

waveasm/test/Transforms/lit-regalloc.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,52 @@ waveasm.program @regalloc_sgpr target = #waveasm.target<#waveasm.gfx942, 5> abi
5656

5757
waveasm.s_endpgm
5858
}
59+
60+
// CHECK-LABEL: waveasm.program @regalloc_pack_store
waveasm.program @regalloc_pack_store target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
  // Test: Pack inputs get contiguous physical registers matching the pack result.
  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
  %imm0 = waveasm.constant 0 : !waveasm.imm<0>

  %lo = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
  %hi = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg

  // Pack result should get a contiguous 2-wide allocation.
  // Pack inputs should get result.physReg + 0 and result.physReg + 1.
  // Numeric variables are used so that the second input is verified to sit
  // at exactly LO + 1, instead of merely capturing whatever index it got.
  // CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[#LO:]]>
  // CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[#LO+1]]>
  // CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[#LO]], 2>
  %packed = waveasm.pack %lo, %hi : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>

  waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
  waveasm.s_endpgm
}
80+
81+
// CHECK-LABEL: waveasm.program @regalloc_extract_of_pack
waveasm.program @regalloc_extract_of_pack target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
  // Test: Extract from pack result gets correct sub-register.
  %imm0 = waveasm.constant 0 : !waveasm.imm<0>
  %imm1 = waveasm.constant 1 : !waveasm.imm<1>
  %imm2 = waveasm.constant 2 : !waveasm.imm<2>
  %imm3 = waveasm.constant 3 : !waveasm.imm<3>

  %v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
  %v1 = waveasm.v_mov_b32 %imm1 : !waveasm.imm<1> -> !waveasm.vreg
  %v2 = waveasm.v_mov_b32 %imm2 : !waveasm.imm<2> -> !waveasm.vreg
  %v3 = waveasm.v_mov_b32 %imm3 : !waveasm.imm<3> -> !waveasm.vreg

  // BASE is a numeric variable so later checks can do arithmetic on it.
  // CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[#BASE:]], 4>
  %packed = waveasm.pack %v0, %v1, %v2, %v3 : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<4, 4>

  // Extract element 2 should yield physReg = pack result base + 2.
  // CHECK: waveasm.extract {{.*}}[2] : !waveasm.pvreg<[[#BASE]], 4> -> !waveasm.pvreg<[[#BASE+2]]>
  %elem = waveasm.extract %packed[2] : !waveasm.vreg<4, 4> -> !waveasm.vreg

  // Verify the extracted element's physical register is used in the store.
  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
  // CHECK: waveasm.buffer_store_dword {{.*}} : !waveasm.pvreg<[[#BASE+2]]>,
  waveasm.buffer_store_dword %elem, %srd, %voff : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
  waveasm.s_endpgm
}

0 commit comments

Comments
 (0)