Skip to content

Commit 1620f0c

Browse files
authored
Implement more Vega instructions
1 parent c91c772 commit 1620f0c

File tree

1 file changed

+295
-29
lines changed

1 file changed

+295
-29
lines changed

vega-experiments/aco_optimizer.cpp

Lines changed: 295 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,6 +2109,262 @@ combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2,
21092109
return false;
21102110
}
21112111

2112+
/* Recognise (x >> imm) | (y << (32-imm)) -> v_alignbit_b32 y, x, imm
2113+
* and (x >> (8*k)) | (y << (32-8*k)) -> v_alignbyte_b32 y, x, k (0-3)
2114+
*
2115+
* Restrictions:
2116+
* ‑ both shifts must be *_lsh?rev_b32 with a literal shift amount
2117+
* ‑ the shift results must be used only by the OR we are visiting
2118+
* ‑ no modifiers (abs/neg/…)
2119+
* ‑ imm ∈ [1, 31] (0/32 are nop; the OR will already have been DCEd)
2120+
*/
2121+
static bool
2122+
combine_alignbit_like(opt_ctx& ctx,
2123+
aco_ptr<Instruction>& or_instr,
2124+
aco_opcode target_opcode,
2125+
unsigned granularity /* 1 for bit, 8 for byte */)
2126+
{
2127+
/* Sanity: are we at the correct opcode and granularity? */
2128+
if (or_instr->operands.size() != 2 ||
2129+
or_instr->opcode != aco_opcode::v_or_b32 ||
2130+
granularity == 0)
2131+
return false;
2132+
2133+
/* Helper lambdas ------------------------------------------------------ */
2134+
auto is_shift_candidate =
2135+
[&](Operand& op,
2136+
unsigned* shift /*out*/,
2137+
Operand* value /*out*/,
2138+
bool* is_shr /*out*/) -> bool
2139+
{
2140+
if (!op.isTemp())
2141+
return false;
2142+
2143+
Instruction* sh = ctx.info[op.tempId()].parent_instr;
2144+
if (!sh || ctx.uses[op.tempId()] != 1)
2145+
return false;
2146+
2147+
/* match v_lshrrev_b32 imm, src (shift right)
2148+
* or v_lshlrev_b32 imm, src (shift left ) */
2149+
bool shr = false;
2150+
if (sh->opcode == aco_opcode::v_lshrrev_b32)
2151+
shr = true;
2152+
else if (sh->opcode == aco_opcode::v_lshlrev_b32)
2153+
shr = false;
2154+
else
2155+
return false;
2156+
2157+
if (!sh->operands[0].isLiteral() || sh->operands[0].constantValue() >= 32)
2158+
return false;
2159+
2160+
unsigned imm = sh->operands[0].constantValue();
2161+
if (imm == 0 || imm >= 32 || (imm % granularity))
2162+
return false; /* Either no-op or wrong granularity */
2163+
2164+
/* good, fill the out-params */
2165+
if (shift) *shift = imm;
2166+
if (value) *value = sh->operands[1]; /* real source (second operand) */
2167+
if (is_shr) *is_shr = shr;
2168+
return true;
2169+
};
2170+
2171+
/* Try to match both operands ----------------------------------------- */
2172+
unsigned s0_shift = 0, s1_shift = 0;
2173+
Operand s0_value, s1_value;
2174+
bool s0_is_shr = false, s1_is_shr = false;
2175+
2176+
if (!is_shift_candidate(or_instr->operands[0], &s0_shift, &s0_value, &s0_is_shr) ||
2177+
!is_shift_candidate(or_instr->operands[1], &s1_shift, &s1_value, &s1_is_shr))
2178+
return false;
2179+
2180+
/* We need one left and one right shift and complementary amounts ------*/
2181+
if (s0_is_shr == s1_is_shr)
2182+
return false; /* both left or both right */
2183+
2184+
unsigned imm = s0_is_shr ? s0_shift : s1_shift; /* right-shift amount */
2185+
unsigned limm = s0_is_shr ? s1_shift : s0_shift; /* left-shift amount */
2186+
if (imm + limm != 32 || imm % granularity)
2187+
return false; /* not complementary */
2188+
2189+
/* Build replacement instruction -------------------------------------- */
2190+
aco_ptr<Instruction> new_instr{
2191+
create_instruction(target_opcode, Format::VOP3, 3, 1)};
2192+
2193+
/* According to ISA: dst = (src0 << imm) | (src1 >> (N-imm))
2194+
* -> src0 = left-shift source, src1 = right-shift source */
2195+
Operand left_src = s0_is_shr ? s1_value : s0_value;
2196+
Operand right_src = s0_is_shr ? s0_value : s1_value;
2197+
2198+
new_instr->operands[0] = left_src;
2199+
new_instr->operands[1] = right_src;
2200+
new_instr->operands[2] = Operand::c32(imm / granularity); /* byte shift if granularity==8 */
2201+
2202+
new_instr->definitions[0] = or_instr->definitions[0];
2203+
new_instr->pass_flags = or_instr->pass_flags;
2204+
2205+
/* Update SSA-use information ----------------------------------------- */
2206+
ctx.uses[new_instr->operands[0].isTemp() ? new_instr->operands[0].tempId() : 0]++;
2207+
ctx.uses[new_instr->operands[1].isTemp() ? new_instr->operands[1].tempId() : 0]++;
2208+
2209+
ctx.uses[or_instr->operands[0].tempId()]--;
2210+
ctx.uses[or_instr->operands[1].tempId()]--;
2211+
2212+
/* Install the new instruction and fix bookkeeping -------------------- */
2213+
or_instr = std::move(new_instr);
2214+
ctx.info[or_instr->definitions[0].tempId()].parent_instr = or_instr.get();
2215+
return true;
2216+
}
2217+
2218+
static inline bool
2219+
combine_alignbit_b32(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2220+
{
2221+
return combine_alignbit_like(ctx, instr, aco_opcode::v_alignbit_b32, 1 /*bit*/);
2222+
}
2223+
2224+
static inline bool
2225+
combine_alignbyte_b32(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2226+
{
2227+
return combine_alignbit_like(ctx, instr, aco_opcode::v_alignbyte_b32, 8 /*byte*/);
2228+
}
2229+
2230+
static bool
2231+
combine_bfi_b32(opt_ctx& ctx, aco_ptr<Instruction>& or_instr)
2232+
{
2233+
if (or_instr->opcode != aco_opcode::v_or_b32 ||
2234+
ctx.program->gfx_level < GFX9)
2235+
return false;
2236+
2237+
auto match_and =
2238+
[&](Operand& in,
2239+
Operand* val /*out*/,
2240+
uint32_t* lit /*out*/) -> bool
2241+
{
2242+
if (!in.isTemp())
2243+
return false;
2244+
2245+
Instruction* and_instr = ctx.info[in.tempId()].parent_instr;
2246+
if (!and_instr ||
2247+
ctx.uses[in.tempId()] != 1 ||
2248+
and_instr->opcode != aco_opcode::v_and_b32)
2249+
return false;
2250+
2251+
/* One operand must be literal, the other is the real source. */
2252+
if (and_instr->operands[0].isLiteral()) {
2253+
*lit = and_instr->operands[0].constantValue();
2254+
*val = and_instr->operands[1];
2255+
} else if (and_instr->operands[1].isLiteral()) {
2256+
*lit = and_instr->operands[1].constantValue();
2257+
*val = and_instr->operands[0];
2258+
} else {
2259+
return false;
2260+
}
2261+
return true;
2262+
};
2263+
2264+
/* Try to match both OR operands … */
2265+
Operand base, ins;
2266+
uint32_t lit_base_mask = 0, lit_ins_mask = 0;
2267+
2268+
if (!match_and(or_instr->operands[0], &base, &lit_base_mask) ||
2269+
!match_and(or_instr->operands[1], &ins, &lit_ins_mask))
2270+
return false;
2271+
2272+
/* … and make sure the masks are complements. */
2273+
if ((lit_base_mask ^ lit_ins_mask) != 0xffffffffu)
2274+
return false;
2275+
2276+
/* Decide which part is the insert and which the base so that
2277+
* dst = (mask & ins) | (~mask & base). */
2278+
Operand op_base, op_ins;
2279+
uint32_t mask = 0;
2280+
if ((lit_ins_mask & lit_base_mask) == 0)
2281+
{
2282+
/* One is mask, the other is ~mask. We need (mask & ins). */
2283+
if (lit_ins_mask < lit_base_mask) {
2284+
mask = lit_ins_mask;
2285+
op_ins = ins;
2286+
op_base = base;
2287+
} else {
2288+
mask = lit_base_mask;
2289+
op_ins = base;
2290+
op_base = ins;
2291+
}
2292+
} else {
2293+
/* Not complementary in the expected way. */
2294+
return false;
2295+
}
2296+
2297+
/* Build v_bfi_b32 --------------------------------------------------- */
2298+
aco_ptr<Instruction> bfi{
2299+
create_instruction(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1)};
2300+
2301+
bfi->operands[0] = op_base; /* src0 */
2302+
bfi->operands[1] = op_ins; /* src1 */
2303+
bfi->operands[2] = Operand::c32(mask); /* src2 (mask) */
2304+
2305+
bfi->definitions[0] = or_instr->definitions[0];
2306+
bfi->pass_flags = or_instr->pass_flags;
2307+
2308+
/* Update use-counts (the two &-results disappear) ------------------ */
2309+
ctx.uses[or_instr->operands[0].tempId()]--;
2310+
ctx.uses[or_instr->operands[1].tempId()]--;
2311+
2312+
/* Insert the new instruction, patch SSA info ----------------------- */
2313+
or_instr = std::move(bfi);
2314+
ctx.info[or_instr->definitions[0].tempId()].parent_instr = or_instr.get();
2315+
return true;
2316+
}
2317+
2318+
static bool
2319+
combine_bfe_b32(opt_ctx& ctx, aco_ptr<Instruction>& and_instr)
2320+
{
2321+
if (and_instr->opcode != aco_opcode::v_and_b32 ||
2322+
ctx.program->gfx_level < GFX8)
2323+
return false;
2324+
2325+
/* literal contiguous mask? ----------------------------------- */
2326+
if (!and_instr->operands[1].isLiteral())
2327+
return false;
2328+
uint32_t mask = and_instr->operands[1].constantValue();
2329+
if (!mask || (mask & (mask + 1u)) != 0)
2330+
return false; /* not 0…01…1 */
2331+
2332+
unsigned width = util_bitcount(mask); /* popcount == width */
2333+
if (width > 32)
2334+
return false;
2335+
2336+
/* unique user and upstream shift ------------------------------ */
2337+
if (!and_instr->operands[0].isTemp() ||
2338+
ctx.uses[and_instr->operands[0].tempId()] != 1)
2339+
return false;
2340+
Instruction* sh = ctx.info[and_instr->operands[0].tempId()].parent_instr;
2341+
if (!sh || !(sh->opcode == aco_opcode::v_lshrrev_b32 ||
2342+
sh->opcode == aco_opcode::v_ashrrev_i32))
2343+
return false;
2344+
if (!sh->operands[0].isLiteral() ||
2345+
sh->operands[0].constantValue() >= 32)
2346+
return false;
2347+
2348+
unsigned offset = sh->operands[0].constantValue();
2349+
2350+
/* build bfe ---------------------------------------------------- */
2351+
aco_opcode bfe_op = sh->opcode == aco_opcode::v_lshrrev_b32 ?
2352+
aco_opcode::v_bfe_u32 : aco_opcode::v_bfe_i32;
2353+
2354+
aco_ptr<Instruction> bfe{create_instruction(bfe_op, Format::VOP3, 3, 1)};
2355+
bfe->operands[0] = sh->operands[1]; /* src */
2356+
bfe->operands[1] = Operand::c32(offset); /* offset */
2357+
bfe->operands[2] = Operand::c32(width); /* width */
2358+
bfe->definitions[0] = and_instr->definitions[0];
2359+
bfe->pass_flags = and_instr->pass_flags;
2360+
2361+
/* fix uses & SSA ---------------------------------------------- */
2362+
ctx.uses[and_instr->operands[0].tempId()]--;
2363+
and_instr = std::move(bfe);
2364+
ctx.info[and_instr->definitions[0].tempId()].parent_instr = and_instr.get();
2365+
return true;
2366+
}
2367+
21122368
/* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
21132369
bool
21142370
combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
@@ -3248,51 +3504,58 @@ static inline bool can_use_inline_constant(GfxEnum /*gfx_level*/,
32483504
}
32493505

32503506
static inline bool
3251-
is_literal_valid_for_vop3p_vega(aco_opcode op, const Operand& lit)
3507+
is_literal_valid_for_vop3p_vega(aco_opcode op, const Operand& lit) noexcept
32523508
{
32533509
if (op == aco_opcode::v_pk_mad_i16)
32543510
return false;
32553511

3256-
if (op == aco_opcode::v_pk_fma_f16 || op == aco_opcode::v_pk_mad_u16)
3512+
if (op == aco_opcode::v_pk_fma_f16 || op == aco_opcode::v_pk_mad_u16) {
32573513
return lit.isConstant() &&
32583514
can_use_inline_constant(GFX9, lit.constantValue());
3515+
}
32593516

32603517
return true;
32613518
}
32623519

3263-
static void
3264-
propagate_swizzles_vega(VALU_instruction* instr,
3265-
bool opsel_lo,
3266-
bool opsel_hi)
3520+
static void propagate_swizzles_vega(VALU_instruction* v,
3521+
bool opsel_lo,
3522+
bool opsel_hi) noexcept
32673523
{
3524+
if (!v) return; /* Safety check */
3525+
32683526
constexpr unsigned N = 3;
32693527

3528+
/* Whole-word flip: both halves come from former "hi" ------------- */
32703529
if (opsel_lo && opsel_hi) {
32713530
for (unsigned s = 0; s < N; ++s) {
3272-
const bool d_op = instr->opsel_lo[s] ^ instr->opsel_hi[s];
3273-
instr->opsel_lo[s] ^= d_op;
3274-
instr->opsel_hi[s] ^= d_op;
3275-
3276-
const bool d_ng = instr->neg_lo[s] ^ instr->neg_hi[s];
3277-
instr->neg_lo[s] ^= d_ng;
3278-
instr->neg_hi[s] ^= d_ng;
3531+
/* Use XOR swap for bitfields since they support it efficiently */
3532+
const bool d1 = v->opsel_lo[s] ^ v->opsel_hi[s];
3533+
v->opsel_lo[s] ^= d1;
3534+
v->opsel_hi[s] ^= d1;
3535+
3536+
const bool d2 = v->neg_lo[s] ^ v->neg_hi[s];
3537+
v->neg_lo[s] ^= d2;
3538+
v->neg_hi[s] ^= d2;
32793539
}
32803540
return;
32813541
}
32823542

3283-
const bool hi_to_lo = opsel_lo;
3284-
const bool lo_to_hi = !opsel_hi;
3543+
/* Partial swizzle cases ------------------------------------------ */
3544+
const bool hi_to_lo = opsel_lo; /* move hi → lo */
3545+
const bool lo_to_hi = !opsel_hi; /* move lo → hi */
32853546

32863547
for (unsigned s = 0; s < N; ++s) {
3287-
const bool src_lo = instr->opsel_lo[s];
3288-
const bool src_hi = instr->opsel_hi[s];
3289-
instr->opsel_lo[s] = hi_to_lo ? src_hi : src_lo;
3290-
instr->opsel_hi[s] = lo_to_hi ? src_lo : src_hi;
3291-
3292-
const bool n_lo = instr->neg_lo[s];
3293-
const bool n_hi = instr->neg_hi[s];
3294-
instr->neg_lo[s] = hi_to_lo ? n_hi : n_lo;
3295-
instr->neg_hi[s] = lo_to_hi ? n_lo : n_hi;
3548+
const bool orig_opsel_lo = v->opsel_lo[s];
3549+
const bool orig_opsel_hi = v->opsel_hi[s];
3550+
const bool orig_neg_lo = v->neg_lo[s];
3551+
const bool orig_neg_hi = v->neg_hi[s];
3552+
3553+
/* Apply conditional assignments based on swizzle flags */
3554+
v->opsel_lo[s] = hi_to_lo ? orig_opsel_hi : orig_opsel_lo;
3555+
v->opsel_hi[s] = lo_to_hi ? orig_opsel_lo : orig_opsel_hi;
3556+
3557+
v->neg_lo[s] = hi_to_lo ? orig_neg_hi : orig_neg_lo;
3558+
v->neg_hi[s] = lo_to_hi ? orig_neg_lo : orig_neg_hi;
32963559
}
32973560
}
32983561

@@ -4157,14 +4420,17 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
41574420
return;
41584421
}
41594422
}
4160-
} else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4161-
if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4162-
1 | 2)) {
4423+
} else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4424+
if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) {
41634425
} else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4164-
"012", 1 | 2)) {
4426+
"012", 1 | 2)) {
41654427
} else if (combine_add_or_then_and_lshl(ctx, instr)) {
41664428
} else if (combine_v_andor_not(ctx, instr)) {
4167-
}
4429+
} else if (combine_alignbit_b32(ctx, instr)) {
4430+
} else if (combine_alignbyte_b32(ctx, instr)) {
4431+
} else if (combine_bfi_b32(ctx, instr)) {
4432+
} else if (combine_bfe_b32(ctx, instr)) {
4433+
}
41684434
} else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
41694435
if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
41704436
1 | 2)) {

0 commit comments

Comments
 (0)