[Wave][Draft] LDS Transpose Read #1005

Draft

efric wants to merge 3 commits into iree-org:main from efric:lds_tr_draft_cleanup

Conversation

efric (Member) commented Jul 3, 2025

Overview

This PR describes the WIP implementation of LDS transpose read support in Wave. Currently, rocdl.ds.read.tr8.b64 is generated and produces numerically correct results for a single MFMA op. For convenience in testing, we use MMAType.I32_16x16x32_I8. Matrix B is the subject of the transpose: its shape 16x32 [non-K, K] is transposed to 32x16 [K, non-K]. Please see the helper test testi8NontransposeGemm for specific test details.
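For a quick numerical sanity check, the single-tile computation can be modeled in plain NumPy (a sketch only; the array names below are illustrative, not Wave APIs):

import numpy as np

# MMAType.I32_16x16x32_I8: M = N = 16, K = 32, i8 inputs, i32 accumulator.
M, N, K = 16, 16, 32
rng = np.random.default_rng(0)

a = rng.integers(-128, 128, size=(M, K), dtype=np.int8)     # A: [M, K]
b_nk = rng.integers(-128, 128, size=(N, K), dtype=np.int8)  # B as given: [non-K, K] = 16x32
b_kn = b_nk.T                                               # transposed: [K, non-K] = 32x16

# The MFMA consumes B in [K, non-K] order, so the reference result is:
c = a.astype(np.int32) @ b_kn.astype(np.int32)              # C: [M, N], i32 accumulate
assert c.shape == (M, N)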

Implementation

Code generation for this instruction affects both the global memory load and the LDS load. Without interception, Wave generates strided global reads that are merged into a vector with vector.from_elements. The data is then written to shared memory, which is shaped in a transposed layout to match the expected data flow, so no transpose is needed when the data is loaded from LDS.
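As a rough NumPy analogy (not Wave code, and not the exact allocation shapes or index sequences used here), the difference is where the transpose happens: on the write path into LDS, or on the read path out of it.

import numpy as np

b_nk = np.arange(16 * 32).reshape(16, 32)  # global tile: [non-K, K]

# Software transpose (current path): the transpose is folded into the write,
# so reads out of the staging buffer need no further transpose.
staged_sw = b_nk.T.copy()            # [K, non-K]
frag_sw = staged_sw[0:8, :]

# Transpose-on-read (the idea behind ds_read_tr8_b64): write the tile as-is and
# let the read produce the transposed fragment (modeled here by .T).
staged_hw = b_nk.copy()              # [non-K, K]
frag_hw = staged_hw[:, 0:8].T

assert np.array_equal(frag_sw, frag_hw)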

The current (prototype) implementation works in two stages.

1. mark_hardware_transpose_candidates pass

A new pass is introduced to identify candidates for hardware-based transpose, using logic inspired by the existing in-thread transpose detection. If an LDS allocation satisfies the requirements, we:

  • Mark the allocation with a hardware_transpose annotation of the appropriate type.
  • Transpose the shape of the write operation back to the original (non-transposed) shape so that the LDS allocation remains in its expected layout.

However, these modifications alone are not sufficient to produce the desired behavior. The index sequence remains the one computed for the original (software transpose) version, so even though the shape is now correct, the affine maps generated for the access patterns would be incorrect. Without adjusting the index sequence of the shared memory allocation, we would end up writing to LDS as if it were shaped 16x32, even though the transposed LDS shape is actually 32x16 [K, N]. The index sequence for the global read is modified in the same way so that the global memory access pattern matches expectations. For now, the index sequence is hardcoded with the correct values in modify_index_for_full_coverage.
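A minimal, self-contained sketch of the marking step, using a hypothetical stand-in type rather than Wave's real graph nodes (only the hardware_transpose annotation name comes from the description above; everything else is illustrative):

from dataclasses import dataclass, field

@dataclass
class LdsWriteStandIn:
    """Hypothetical stand-in for a write into a shared-memory allocation."""
    shape: tuple       # shape the write currently targets, e.g. (32, 16)
    element_bits: int  # 8 for the i8 case in this PR
    annotations: dict = field(default_factory=dict)

def mark_hardware_transpose_candidates(writes):
    """Toy version of the pass: annotate eligible allocations and transpose the
    write shape back to the original (non-transposed) layout. Index-sequence
    fixes (modify_index_for_full_coverage today) are a separate step."""
    for write in writes:
        # Only the 8-bit case (ds_read_tr8_b64) is handled in this prototype.
        if write.element_bits != 8:
            continue
        write.annotations["hardware_transpose"] = "tr8_b64"  # mark the allocation
        write.shape = (write.shape[1], write.shape[0])       # undo the write-side transpose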

2. Interception of Shared Memory Read Before MFMA

In the second stage, we intercept reads from shared memory before they are passed into an MFMA op. If the allocation has been marked as a hardware transpose candidate, the read is lowered to use ds.read.tr{n}.b64 semantics. For more details, please see the instruction semantics here. As the name suggests, the thread access patterns in tid_mapping_i8_16x16x32 are simplified to support only single-rate 8-bit MFMA instructions. Briefly, each row (dim K) has 2 threads covering 8 (dim non-K) 8-bit elements each ([0:7] and [8:15] respectively): even-numbered thread ids own the first 16x8 elements and odd-numbered thread ids the subsequent 16x8 elements.
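A toy enumeration of that ownership pattern (the column halves follow the description above; the row assignment tid // 2 is an assumption for illustration and may not match what tid_mapping_i8_16x16x32 actually computes):

K, NON_K, THREADS_PER_WAVE = 32, 16, 64

def owned_elements(tid):
    row = tid // 2                   # dim K: two threads per row (assumed ordering)
    col0 = 0 if tid % 2 == 0 else 8  # dim non-K: even tids take [0:7], odd tids [8:15]
    return [(row, col0 + c) for c in range(8)]

# Sanity check: the 64 threads tile the 32x16 [K, non-K] fragment exactly once.
covered = [e for tid in range(THREADS_PER_WAVE) for e in owned_elements(tid)]
assert sorted(covered) == [(k, n) for k in range(K) for n in range(NON_K)]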

Remaining Work

modify_index_for_full_coverage needs to be removed, and the process of correcting the access patterns for hardware transpose loads needs to be generalized into something more principled. global_to_shared_gathers.py looks like a promising pass to piggyback on: its current requirements are too limiting, but they can be relaxed to support the hardware transpose path. Ideally this lets us offload some of the hacky logic in hardware_transpose.

Thread mappings may also be generalized by breaking the thread subgroups according to the instruction semantics linked above.

There is now an AMDGPU dialect wrapper for the rocdl transpose loads, which will help simplify emit_hardware_transpose_intrinsic.

Misc

Since this PR is WIP, some things are turned off for convenience (e.g. padding); I found this helpful for debugging, but it is not necessary to produce correct results.

@efric efric changed the title from "[LDS Transpose Read] Draft PR (hard coded variant)" to "[Wave][Draft] LDS Transpose Read" on Jul 3, 2025
@efric efric marked this pull request as draft July 4, 2025 00:00
Comment on lines +647 to +665
def emit_hardware_transpose_intrinsic(
    vector_type: VectorType, stride, kb_src, kb_ir_type, hardware_constraint, emitter
) -> Value:
    # Lane id within the wave selects which part of the LDS tile this thread reads.
    tid = hardware_constraint.linearized_thread_id % hardware_constraint.threads_per_wave
    # Compute the per-thread LDS address for the transpose read.
    final_address = tid_mapping_i8_16x16x32(kb_src, tid, stride, emitter)
    i64_type = IntegerType.get_signless(64)
    final_address_i64 = arith_d.index_cast(i64_type, final_address)
    # ds_read_tr8_b64 takes a shared-memory (address space 3) pointer.
    ptr_type = llvm_d.PointerType.get(address_space=3, context=kb_ir_type.context)
    llvm_ptr = llvm_d.inttoptr(ptr_type, final_address_i64)

    # The intrinsic returns 64 bits per lane, packed as a vector<2xi32>.
    i32_type = IntegerType.get_signless(32)
    i32_vec_type = VectorType.get([2], i32_type)
    packed_result = rocdl_d.ds_read_tr8_b64(i32_vec_type, llvm_ptr)

    # Reinterpret the packed 64 bits as 8 elements of the target element type (i8 here).
    vtype = vector_type.element_type
    vec8_v_type = VectorType.get([8], vtype)
    result = vector_d.bitcast(vec8_v_type, packed_result)
    return result

Contributor:

This could maybe be upstreamed as part of the amdgpu dialect.

Member Author:

It has since been upstreamed. Updated TODO in description to include that; will make that part simpler 😄 .
