Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions waveasm/lib/Transforms/LinearScanPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@ namespace waveasm {
#include "waveasm/Transforms/Passes.h.inc"
} // namespace waveasm

/// Get a value's physical register index from the mapping, falling back to
/// the type's index for already-physical (precolored) values.
/// Returns -1 if the value has no physical register assignment.
static int64_t getEffectivePhysReg(Value value,
const PhysicalMapping &mapping) {
int64_t physReg = mapping.getPhysReg(value);
if (physReg >= 0)
return physReg;
Type ty = value.getType();
if (auto pvreg = dyn_cast<PVRegType>(ty))
return pvreg.getIndex();
if (auto pareg = dyn_cast<PARegType>(ty))
return pareg.getIndex();
if (auto psreg = dyn_cast<PSRegType>(ty))
return psreg.getIndex();
return -1;
}

/// Convert a virtual register type to a physical register type.
/// Also handles re-indexing an already-physical type to a new physReg.
/// Returns the original type unchanged if it's not a register type
Expand Down Expand Up @@ -236,30 +254,40 @@ struct LinearScanPass

auto [mapping, stats] = *result;

// Handle waveasm.extract ops: result = source[offset]
// Set the extract result's physical register = source's physical register +
// offset
// Handle waveasm.extract ops: result = source[offset].
// Set the extract result's physical register = source's physReg + offset.
program.walk([&](ExtractOp extractOp) {
Value source = extractOp.getVector();
Value extractResult = extractOp.getResult();
int64_t index = extractOp.getIndex();

// Get source's physical register (may be precolored or allocated)
int64_t sourcePhysReg = -1;
Type srcType = source.getType();
if (auto pvreg = dyn_cast<PVRegType>(srcType)) {
sourcePhysReg = pvreg.getIndex();
} else if (auto pareg = dyn_cast<PARegType>(srcType)) {
sourcePhysReg = pareg.getIndex();
} else {
sourcePhysReg = mapping.getPhysReg(source);
}
int64_t sourcePhysReg =
getEffectivePhysReg(extractOp.getVector(), mapping);
if (sourcePhysReg >= 0)
mapping.setPhysReg(extractOp.getResult(),
sourcePhysReg + extractOp.getIndex());
});

if (sourcePhysReg >= 0) {
// Set the extract result to source + offset
mapping.setPhysReg(extractResult, sourcePhysReg + index);
// Handle waveasm.pack ops: input[i] gets result's physReg + i.
// Pack inputs were excluded from the allocation worklists during liveness
// analysis, so they have no mapping yet. Assign them here from the pack
// result's contiguous allocation.
WalkResult packResult = program.walk([&](PackOp packOp) {
int64_t resultPhysReg = getEffectivePhysReg(packOp.getResult(), mapping);
if (resultPhysReg < 0) {
packOp.emitError(
"pack result has no physical register; cannot assign inputs");
return WalkResult::interrupt();
}
llvm::DenseSet<Value> seen;
for (auto [i, input] : llvm::enumerate(packOp.getElements())) {
if (!seen.insert(input).second) {
packOp.emitError("duplicate pack input at index ")
<< i << "; each input must be a distinct value";
return WalkResult::interrupt();
}
mapping.setPhysReg(input, resultPhysReg + static_cast<int64_t>(i));
}
return WalkResult::advance();
});
if (packResult.wasInterrupted())
return failure();

// Transform the IR: replace virtual register types with physical types
OpBuilder builder(program.getContext());
Expand Down
34 changes: 34 additions & 0 deletions waveasm/lib/Transforms/Liveness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,40 @@ LivenessInfo computeLiveness(ProgramOp program) {
// already handle all necessary live range extensions by directly inspecting
// the region structure.

// Pass 3a: Pack group pass -- treat pack inputs as sub-registers of the
// pack result.
//
// waveasm.pack emits no assembly; it is a register allocation directive
// declaring that N inputs form a contiguous register block. The pack result
// already gets a correct contiguous allocation via allocRange, but pack
// inputs would otherwise get independent allocations to arbitrary registers.
//
// Fix: extend the pack result's live range to cover the full lifetime of
// all inputs (both defs and uses), then remove pack inputs from the
// allocation worklists entirely. A post-pass in LinearScanPass assigns
// input[i].physReg = result.physReg + i.
program.walk([&](PackOp packOp) {
Value packResult = packOp.getResult();
auto resultIt = info.ranges.find(packResult);
assert(resultIt != info.ranges.end() &&
"pack result must have a live range");

for (Value input : packOp.getElements()) {
// Extend the pack result's range to cover this input's full lifetime.
// Inputs may have independent uses after the pack op, so we must
// extend both start and end to avoid missing those uses.
auto inputIt = info.ranges.find(input);
if (inputIt != info.ranges.end()) {
resultIt->second.start =
std::min(resultIt->second.start, inputIt->second.start);
resultIt->second.end =
std::max(resultIt->second.end, inputIt->second.end);
// Remove from ranges so it won't enter the allocator.
info.ranges.erase(inputIt);
}
}
});

// Pass 3b: Build tied equivalence classes for pressure de-duplication.
//
// LoopOp results, condition iter_args, and block args are all tied to the
Expand Down
16 changes: 16 additions & 0 deletions waveasm/test/Transforms/lit-regalloc-error-pack-dup.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: not waveasm-translate --waveasm-linear-scan %s 2>&1 | FileCheck %s
//
// Test: Duplicate pack inputs are rejected during register allocation.

waveasm.program @pack_duplicate_input target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%imm0 = waveasm.constant 0 : !waveasm.imm<0>
%v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg

// CHECK: error: duplicate pack input at index 1; each input must be a distinct value
%packed = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>

%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
waveasm.s_endpgm
}
73 changes: 73 additions & 0 deletions waveasm/test/Transforms/lit-regalloc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,76 @@ waveasm.program @regalloc_sgpr target = #waveasm.target<#waveasm.gfx942, 5> abi

waveasm.s_endpgm
}

// CHECK-LABEL: waveasm.program @regalloc_pack_store
waveasm.program @regalloc_pack_store target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
// Test: Pack inputs get contiguous physical registers matching the pack result.
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
%imm0 = waveasm.constant 0 : !waveasm.imm<0>

%lo = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
%hi = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg

// Pack result should get a contiguous 2-wide allocation.
// Pack inputs should get result.physReg + 0 and result.physReg + 1.
// CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[LO:[0-9]+]]>
// CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[HI:[0-9]+]]>
// CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[LO]], 2>
%packed = waveasm.pack %lo, %hi : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>

waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
waveasm.s_endpgm
}

// CHECK-LABEL: waveasm.program @regalloc_pack_input_post_use
waveasm.program @regalloc_pack_input_post_use target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
// Test: Pack input used independently after the pack op.
// The pack result's live range must extend to cover the input's post-pack use,
// otherwise the allocator could reuse the register prematurely.
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
%imm0 = waveasm.constant 0 : !waveasm.imm<0>

%lo = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
%hi = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg

// CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[BASE:[0-9]+]], 2>
%packed = waveasm.pack %lo, %hi : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>

waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>

// %lo is used again after the pack. The pack result's live range must cover
// this point so that %lo's physical register (BASE+0) is not reallocated.
// CHECK: waveasm.buffer_store_dword {{.*}} : !waveasm.pvreg<[[BASE]]>,
waveasm.buffer_store_dword %lo, %srd, %voff : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
waveasm.s_endpgm
}

// CHECK-LABEL: waveasm.program @regalloc_extract_of_pack
waveasm.program @regalloc_extract_of_pack target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
// Test: Extract from pack result gets correct sub-register.
%imm0 = waveasm.constant 0 : !waveasm.imm<0>
%imm1 = waveasm.constant 1 : !waveasm.imm<1>
%imm2 = waveasm.constant 2 : !waveasm.imm<2>
%imm3 = waveasm.constant 3 : !waveasm.imm<3>

%v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
%v1 = waveasm.v_mov_b32 %imm1 : !waveasm.imm<1> -> !waveasm.vreg
%v2 = waveasm.v_mov_b32 %imm2 : !waveasm.imm<2> -> !waveasm.vreg
%v3 = waveasm.v_mov_b32 %imm3 : !waveasm.imm<3> -> !waveasm.vreg

// CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[BASE:[0-9]+]], 4>
%packed = waveasm.pack %v0, %v1, %v2, %v3 : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<4, 4>

// Extract element 2 should yield physReg = pack result base + 2.
// CHECK: waveasm.extract {{.*}}[2] : !waveasm.pvreg<[[BASE]], 4> -> !waveasm.pvreg<[[ELEM:[0-9]+]]>
%elem = waveasm.extract %packed[2] : !waveasm.vreg<4, 4> -> !waveasm.vreg

// Verify the extracted element's physical register is used in the store.
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
// CHECK: waveasm.buffer_store_dword {{.*}} : !waveasm.pvreg<[[ELEM]]>,
waveasm.buffer_store_dword %elem, %srd, %voff : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
waveasm.s_endpgm
}
Loading