diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp
index fef599358d..f35973b98c 100644
--- a/waveasm/lib/Transforms/LinearScanPass.cpp
+++ b/waveasm/lib/Transforms/LinearScanPass.cpp
@@ -33,6 +33,24 @@ namespace waveasm {
 #include "waveasm/Transforms/Passes.h.inc"
 } // namespace waveasm
 
+/// Get a value's physical register index from the mapping, falling back to
+/// the type's index for already-physical (precolored) values.
+/// Returns -1 if the value has no physical register assignment.
+static int64_t getEffectivePhysReg(Value value,
+                                   const PhysicalMapping &mapping) {
+  int64_t physReg = mapping.getPhysReg(value);
+  if (physReg >= 0)
+    return physReg;
+  Type ty = value.getType();
+  if (auto pvreg = dyn_cast<PVRegType>(ty))
+    return pvreg.getIndex();
+  if (auto pareg = dyn_cast<PARegType>(ty))
+    return pareg.getIndex();
+  if (auto psreg = dyn_cast<PSRegType>(ty))
+    return psreg.getIndex();
+  return -1;
+}
+
 /// Convert a virtual register type to a physical register type.
 /// Also handles re-indexing an already-physical type to a new physReg.
 /// Returns the original type unchanged if it's not a register type
@@ -236,30 +254,40 @@ struct LinearScanPass
 
     auto [mapping, stats] = *result;
 
-    // Handle waveasm.extract ops: result = source[offset]
-    // Set the extract result's physical register = source's physical register +
-    // offset
+    // Handle waveasm.extract ops: result = source[offset].
+    // Set the extract result's physical register = source's physReg + offset.
program.walk([&](ExtractOp extractOp) { - Value source = extractOp.getVector(); - Value extractResult = extractOp.getResult(); - int64_t index = extractOp.getIndex(); - - // Get source's physical register (may be precolored or allocated) - int64_t sourcePhysReg = -1; - Type srcType = source.getType(); - if (auto pvreg = dyn_cast(srcType)) { - sourcePhysReg = pvreg.getIndex(); - } else if (auto pareg = dyn_cast(srcType)) { - sourcePhysReg = pareg.getIndex(); - } else { - sourcePhysReg = mapping.getPhysReg(source); - } + int64_t sourcePhysReg = + getEffectivePhysReg(extractOp.getVector(), mapping); + if (sourcePhysReg >= 0) + mapping.setPhysReg(extractOp.getResult(), + sourcePhysReg + extractOp.getIndex()); + }); - if (sourcePhysReg >= 0) { - // Set the extract result to source + offset - mapping.setPhysReg(extractResult, sourcePhysReg + index); + // Handle waveasm.pack ops: input[i] gets result's physReg + i. + // Pack inputs were excluded from the allocation worklists during liveness + // analysis, so they have no mapping yet. Assign them here from the pack + // result's contiguous allocation. 
+    WalkResult packResult = program.walk([&](PackOp packOp) {
+      int64_t resultPhysReg = getEffectivePhysReg(packOp.getResult(), mapping);
+      if (resultPhysReg < 0) {
+        packOp.emitError(
+            "pack result has no physical register; cannot assign inputs");
+        return WalkResult::interrupt();
       }
+      llvm::DenseSet<Value> seen;
+      for (auto [i, input] : llvm::enumerate(packOp.getElements())) {
+        if (!seen.insert(input).second) {
+          packOp.emitError("duplicate pack input at index ")
+              << i << "; each input must be a distinct value";
+          return WalkResult::interrupt();
+        }
+        mapping.setPhysReg(input, resultPhysReg + static_cast<int64_t>(i));
+      }
+      return WalkResult::advance();
     });
+    if (packResult.wasInterrupted())
+      return failure();
 
     // Transform the IR: replace virtual register types with physical types
     OpBuilder builder(program.getContext());
diff --git a/waveasm/lib/Transforms/Liveness.cpp b/waveasm/lib/Transforms/Liveness.cpp
index b4c3e987f8..753517ddfa 100644
--- a/waveasm/lib/Transforms/Liveness.cpp
+++ b/waveasm/lib/Transforms/Liveness.cpp
@@ -478,6 +478,40 @@ LivenessInfo computeLiveness(ProgramOp program) {
   // already handle all necessary live range extensions by directly inspecting
   // the region structure.
 
+  // Pass 3a: Pack group pass -- treat pack inputs as sub-registers of the
+  // pack result.
+  //
+  // waveasm.pack emits no assembly; it is a register allocation directive
+  // declaring that N inputs form a contiguous register block. The pack result
+  // already gets a correct contiguous allocation via allocRange, but pack
+  // inputs would otherwise get independent allocations to arbitrary registers.
+  //
+  // Fix: extend the pack result's live range to cover the full lifetime of
+  // all inputs (both defs and uses), then remove pack inputs from the
+  // allocation worklists entirely. A post-pass in LinearScanPass assigns
+  // input[i].physReg = result.physReg + i.
+  program.walk([&](PackOp packOp) {
+    Value packResult = packOp.getResult();
+    auto resultIt = info.ranges.find(packResult);
+    assert(resultIt != info.ranges.end() &&
+           "pack result must have a live range");
+
+    for (Value input : packOp.getElements()) {
+      // Merge this input's whole lifetime (def and any post-pack uses) into
+      // the pack result's range by taking the min start and max end.
+      // NOTE(review): assumes erase() below leaves resultIt valid -- confirm.
+      auto inputIt = info.ranges.find(input);
+      if (inputIt != info.ranges.end()) {
+        resultIt->second.start =
+            std::min(resultIt->second.start, inputIt->second.start);
+        resultIt->second.end =
+            std::max(resultIt->second.end, inputIt->second.end);
+        // Drop the input's own range so it never enters the allocator.
+        info.ranges.erase(inputIt);
+      }
+    }
+  });
+
   // Pass 3b: Build tied equivalence classes for pressure de-duplication.
   //
   // LoopOp results, condition iter_args, and block args are all tied to the
diff --git a/waveasm/test/Transforms/lit-regalloc-error-pack-dup.mlir b/waveasm/test/Transforms/lit-regalloc-error-pack-dup.mlir
new file mode 100644
index 0000000000..28836f2687
--- /dev/null
+++ b/waveasm/test/Transforms/lit-regalloc-error-pack-dup.mlir
@@ -0,0 +1,16 @@
+// RUN: not waveasm-translate --waveasm-linear-scan %s 2>&1 | FileCheck %s
+//
+// Test: Duplicate pack inputs are rejected during register allocation.
+
+waveasm.program @pack_duplicate_input target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %imm0 = waveasm.constant 0 : !waveasm.imm<0>
+  %v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+
+  // CHECK: error: duplicate pack input at index 1; each input must be a distinct value
+  %packed = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>
+
+  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
+  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/lit-regalloc.mlir b/waveasm/test/Transforms/lit-regalloc.mlir
index 93e3b31a39..3f5cc9e4ae 100644
--- a/waveasm/test/Transforms/lit-regalloc.mlir
+++ b/waveasm/test/Transforms/lit-regalloc.mlir
@@ -56,3 +56,76 @@ waveasm.program @regalloc_sgpr target = #waveasm.target<#waveasm.gfx942, 5> abi
 
   waveasm.s_endpgm
 }
+
+// CHECK-LABEL: waveasm.program @regalloc_pack_store
+waveasm.program @regalloc_pack_store target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  // Test: Pack inputs get contiguous physical registers matching the pack result.
+  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
+  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %imm0 = waveasm.constant 0 : !waveasm.imm<0>
+
+  %lo = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+  %hi = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+
+  // Pack result should get a contiguous 2-wide allocation.
+  // Pack inputs should get result.physReg + 0 and result.physReg + 1.
+  // CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[#PACK_LO:]]>
+  // CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.pvreg<[[#PACK_LO+1]]>
+  // CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[#PACK_LO]], 2>
+  %packed = waveasm.pack %lo, %hi : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>
+
+  waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
+  waveasm.s_endpgm
+}
+
+// CHECK-LABEL: waveasm.program @regalloc_pack_input_post_use
+waveasm.program @regalloc_pack_input_post_use target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  // Test: Pack input used independently after the pack op.
+  // The pack result's live range must extend to cover the input's post-pack use,
+  // otherwise the allocator could reuse the register prematurely.
+  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
+  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %imm0 = waveasm.constant 0 : !waveasm.imm<0>
+
+  %lo = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+  %hi = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+
+  // CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[BASE:[0-9]+]], 2>
+  %packed = waveasm.pack %lo, %hi : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2>
+
+  waveasm.buffer_store_dwordx2 %packed, %srd, %voff : !waveasm.vreg<2>, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
+
+  // %lo is used again after the pack. The pack result's live range must cover
+  // this point so that %lo's physical register (BASE+0) is not reallocated.
+  // CHECK: waveasm.buffer_store_dword {{.*}} : !waveasm.pvreg<[[BASE]]>,
+  waveasm.buffer_store_dword %lo, %srd, %voff : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
+  waveasm.s_endpgm
+}
+
+// CHECK-LABEL: waveasm.program @regalloc_extract_of_pack
+waveasm.program @regalloc_extract_of_pack target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  // Test: Extract from pack result gets correct sub-register.
+  %imm0 = waveasm.constant 0 : !waveasm.imm<0>
+  %imm1 = waveasm.constant 1 : !waveasm.imm<1>
+  %imm2 = waveasm.constant 2 : !waveasm.imm<2>
+  %imm3 = waveasm.constant 3 : !waveasm.imm<3>
+
+  %v0 = waveasm.v_mov_b32 %imm0 : !waveasm.imm<0> -> !waveasm.vreg
+  %v1 = waveasm.v_mov_b32 %imm1 : !waveasm.imm<1> -> !waveasm.vreg
+  %v2 = waveasm.v_mov_b32 %imm2 : !waveasm.imm<2> -> !waveasm.vreg
+  %v3 = waveasm.v_mov_b32 %imm3 : !waveasm.imm<3> -> !waveasm.vreg
+
+  // CHECK: waveasm.pack {{.*}} -> !waveasm.pvreg<[[#XPACK_BASE:]], 4>
+  %packed = waveasm.pack %v0, %v1, %v2, %v3 : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<4, 4>
+
+  // Extract element 2 should yield physReg = pack result base + 2.
+  // CHECK: waveasm.extract {{.*}}[2] : !waveasm.pvreg<[[#XPACK_BASE]], 4> -> !waveasm.pvreg<[[#XPACK_BASE+2]]>
+  %elem = waveasm.extract %packed[2] : !waveasm.vreg<4, 4> -> !waveasm.vreg
+
+  // Verify the extracted element's physical register is used in the store.
+  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
+  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  // CHECK: waveasm.buffer_store_dword {{.*}} : !waveasm.pvreg<[[#XPACK_BASE+2]]>,
+  waveasm.buffer_store_dword %elem, %srd, %voff : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<0>
+  waveasm.s_endpgm
+}