
Commit e1b1a85

Refactor tensor load shared mem index to use proper index expr (#450)
Use a proper symbolic index dict for shared mem access instead of a flat int offset for tensor load. Depends on llvm/llvm-project#167615

Signed-off-by: Ivan Butygin <[email protected]>
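In rough terms, the change looks like the following sketch (sympy stands in for wave's symbolic index machinery; the dimension names, wave size, strides, and the 4096-byte offset are made up for illustration, and the real code keys the dict by IndexSymbol with IndexSequence values rather than strings and bare expressions):

import sympy

# Hypothetical stand-ins for wave's $T0 and BLOCK_M index symbols.
T0, BLOCK_M = sympy.symbols("T0 BLOCK_M")

# Before: the shared tile offset was a single flat byte constant, e.g.
shared_tile_index_old = 4096  # bytes into the shared allocation (made up)

# After: a per-dimension symbolic index, kept in elements until codegen.
shared_tile_index_new = {
    "M": BLOCK_M * (T0 // 32),  # tile start along M, per 32-thread wave
    "K": sympy.Integer(0),      # the K tile starts at the row beginning
}

# Codegen then linearizes the dict against the allocation's element strides.
strides = {"M": 64, "K": 1}  # hypothetical row-major strides in elements
linear = sum(expr * strides[dim] for dim, expr in shared_tile_index_new.items())
print(linear)  # 64*BLOCK_M*floor(T0/32)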

File tree

4 files changed, +58 -46 lines

lit_tests/kernel/wave/mma.py

Lines changed: 8 additions & 3 deletions
@@ -667,11 +667,16 @@ def mma(
     # CHECK: %[[D0:.*]] = vector.from_elements
     # CHECK: %[[TENSOR_DESC_0:.*]] = vector.from_elements
     # CHECK: llvm.call_intrinsic "llvm.amdgcn.tensor.load.to.lds"(%[[D0]], %[[TENSOR_DESC_0]], {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, i32) -> ()
-    # CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.s.wait.tensorcnt"
-    # CHECK-NOT: amdgpu.lds_barrier
+    # CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.s.wait.tensorcnt"
+    # CHECK-NOT: amdgpu.lds_barrier
+
+    ### get shared buffer pointer
+    # CHECK: %[[CAST_4:.*]] = memref.reinterpret_cast %[[VIEW0]]
+    # CHECK: %[[INT_PTR_2:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_4]]
+    # CHECK: %[[INT_PTR_2_CAST:.+]] = arith.index_cast %[[INT_PTR_2]] : index to i32
 
     ### pack descriptors and invoke tensor load
-    # CHECK: %[[D1:.*]] = vector.from_elements
+    # CHECK: %[[D1:.*]] = vector.from_elements %{{.*}}, %[[INT_PTR_2_CAST]], %{{.*}}, %{{.*}} : vector<4xi32>
     # CHECK: %[[TENSOR_DESC_1:.*]] = vector.from_elements
 
     ### resource provider

wave_lang/kernel/ops/wave_ops.py

Lines changed: 1 addition & 1 deletion
@@ -3087,7 +3087,7 @@ class TensorLoadToLDS(CustomOp):
     dst: Memory
     element_type: DataType
     distributed_shape: list[IndexExpr]
-    shared_tile_index: int
+    shared_tile_index: dict[IndexSymbol, IndexSequence]
     global_tile_index: dict[IndexSymbol, IndexSequence]
     bounds: dict[IndexSymbol, IndexExpr]

wave_lang/kernel/wave/codegen/read_write.py

Lines changed: 25 additions & 8 deletions
@@ -735,7 +735,7 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
     strides = [strides[0] * symbolic_shape[0]] + strides[:-1]
     strides = [gen_sympy_index(subs, s) for s in strides]
 
-    distributed_shape = [gen_sympy_index(subs, s) for s in distributed_shape]
+    distributed_shape_values = [gen_sympy_index(subs, s) for s in distributed_shape]
 
     # construct default descriptors
     i32 = IntegerType.get_signless(32)
@@ -757,8 +757,8 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
     valid = 1
     dim_stride_1 = arith_d.index_cast(i48, strides[0])
     dim_stride_0 = arith_d.index_cast(i48, strides[1])
-    tile_size_1 = arith_d.index_cast(i32, distributed_shape[0])
-    tile_size_0 = arith_d.index_cast(i32, distributed_shape[1])
+    tile_size_1 = arith_d.index_cast(i32, distributed_shape_values[0])
+    tile_size_0 = arith_d.index_cast(i32, distributed_shape_values[1])
     dim_size_1 = arith_d.index_cast(i32, local_bounds[0])
     dim_size_0 = arith_d.index_cast(i32, local_bounds[1])
@@ -797,16 +797,33 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
     global_byte_address = arith_d.addi(global_ptr, global_index_offset)
 
     # calculate shared address
-    # 0. get allocation space offset in descriptor
-    # 1. get shared memory pointer
-    # 2. move shared memory pointer by offset_byte to get shared memory address of a tile.
+    # 0. extract shared tile index from IndexSequence structure
+    # 1. calculate byte offset from tile indices and distributed shape
+    # 2. get shared memory pointer
+    # 3. move shared memory pointer by offset_byte to get shared memory address of a tile.
     shared_buffer = _linearize_shared_mem(shared_value)
 
-    shared_byte_offset = arith_d.constant(i32, shared_tile_index)
+    shared_strides = strides_from_symbolic_shape(
+        IndexingContext.current(), dst_memory.distributed_shape, allow_mixed_shapes=True
+    )
+    linearized_index = {
+        "linearized_idx": linearize_index(shared_tile_index, shared_strides)
+    }
+
+    # Calculate shared memory offset from tile indices
+    shared_index, _, _ = _build_start_indices(emitter, linearized_index)
+
+    shared_index_offset = arith_d.muli(shared_index[0], element_byte_index)
+    shared_byte_offset = arith_d.index_cast(i32, shared_index_offset)
 
     shared_ptr = memref_d.extract_aligned_pointer_as_index(shared_buffer)
     shared_ptr = arith_d.index_cast(i32, shared_ptr)
-    shared_byte_address = arith_d.addi(shared_ptr, shared_byte_offset)
+
+    shared_ptr_base_offset = memref_d.extract_strided_metadata(shared_buffer)[1]
+    shared_ptr_base_offset = arith_d.index_cast(i32, shared_ptr_base_offset)
+
+    shared_byte_address = arith_d.addi(shared_ptr_base_offset, shared_byte_offset)
+    shared_byte_address = arith_d.addi(shared_ptr, shared_byte_address)
 
     # assume no mapping
     def lshift(value, bits):
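As a sanity check on the address arithmetic in the hunk above, here is a minimal standalone sketch in plain Python (the dimension names, pointer, and stride values are invented; the real code emits each step as an MLIR arith/memref op on i32 values):

def shared_byte_address(aligned_ptr: int, base_offset: int,
                        tile_index: dict[str, int], strides: dict[str, int],
                        element_bytes: int) -> int:
    # 1. linearize the per-dimension tile index against the element strides,
    #    as linearize_index + _build_start_indices do symbolically
    linear_elems = sum(tile_index[d] * strides[d] for d in tile_index)
    # 2. scale the element offset to bytes (the muli with element_byte_index)
    tile_bytes = linear_elems * element_bytes
    # 3. fold in the memref's own base offset from extract_strided_metadata,
    #    then the aligned pointer, mirroring the two addi ops in the diff
    return aligned_ptr + (base_offset + tile_bytes)

# e.g. tile start (2, 0) in a 64-wide f16 buffer with a zero base offset:
assert shared_byte_address(0x7000, 0, {"M": 2, "K": 0}, {"M": 64, "K": 1}, 2) == 0x7000 + 256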

wave_lang/kernel/wave/tensor_load_to_shared.py

Lines changed: 24 additions & 34 deletions
@@ -32,16 +32,15 @@
 ** Global offset (Index Sequence unit : elements)
 global offset is calculated by only perserving global index sequence with wave index sequence.
 this is valid because tensor load instruction expects global base address of a tile.
-** Shared offset (Allocation size unit: bytes)
-shared offset is calculated by materializing the distributed shape from a "write_shared" node.
+** Shared offset (Index Sequence unit: elements)
+shared offset is calculated by preserving the index sequence from a "write_shared" node,
+removing thread offsets within a tile, similar to global offset.
 
 Example:
 For loading tensors with shape M x K to alloc0 (smem), and N x K to alloc1 (smem),
 with tile size = BLOCK_M * BLOCK_K, BLOCK_N x BLOCK_K, and K is the contiguous dimension:
-- global offset perserves BLOCK index and WAVE_ID: $WG0 * BLOCK_M + BLOCK_M * ($T0 // 32)
-- shared offset:
-  - for alloc0 = 0
-  - for alloc1 = BLOCKM * BLOCK_K
+- global offset preserves BLOCK index and WAVE_ID: $WG0 * BLOCK_M + BLOCK_M * ($T0 // 32)
+- shared offset preserves tile-level index: similar structure to global offset
 """
 
 import logging
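To make "removing thread offsets within a tile" concrete, a small hypothetical sympy sketch (the 32-thread wave follows the `$T0 // 32` example in the docstring; the plain substitution is only a stand-in for what remove_global_indexing and the wave_subs mapping actually do):

import sympy

T0, BLOCK_M = sympy.symbols("T0 BLOCK_M")  # stand-ins for $T0 and BLOCK_M

# per-thread shared index: the wave's tile start plus the thread's own
# offset inside the tile
index_start = BLOCK_M * sympy.floor(T0 / 32) + (T0 % 32)

# the tensor load wants a tile base address, so the intra-tile component
# is dropped (modeled here as substituting the Mod term with 0)
tile_start = index_start.subs(T0 % 32, 0)
print(tile_start)  # BLOCK_M*floor(T0/32)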
@@ -77,12 +76,10 @@
     get_hardware_constraint,
     infer_dim,
     is_pow2,
+    remove_global_indexing,
 )
 from .utils.symbol_utils import subs_idxc
 
-from .memory_analysis.minimize_shared_allocs import get_alloc_info
-from .memory_analysis.solver import determine_allocations_offsets
-
 logger = logging.getLogger(__name__)
@@ -115,7 +112,7 @@ class TensorLoadConfig:
     """
     element_type
     tensor_tile_shapes : [tile dim 0 shape, tile dim 1 shape]
-    shared_tile_index (bytes)
+    shared_tile_index (IndexSequence)
     global_tile_index (IndexSequence)
     bounds
@@ -124,7 +121,7 @@
 
     element_type: "DataType"
     distributed_shape: list[IndexExpr]
-    shared_tile_index: int
+    shared_tile_index: dict[IndexSymbol, IndexSequence]
     global_tile_index: dict[IndexSymbol, IndexSequence]
     bounds: dict[IndexSymbol, IndexExpr]
@@ -149,13 +146,18 @@ def get_global_element_offset(
     return {key: IndexSequence(index[key].start, 1, 1) for key in index.keys()}
 
 
-def get_shared_tile_byte_offset(node: fx.Node, alloc_offset_map) -> int:
+def get_shared_element_offset(
+    node: CustomOp, constraints: list[Constraint], wave_subs
+) -> dict[IndexSymbol, IndexSequence]:
     """
-    LDS address = Shared mem buffer + tile offset in bytes
-    This function returns the tile offset.
+    Shared memory address = shared mem buffer + tile offset
+    This function returns the tile index by removing threads offset within a tile.
     """
-    offset_sym = alloc_offset_map[node.memory]
-    return int(offset_sym)
+    assert isinstance(node, Write), "Expect Write custom node as caller argument"
+    index = remove_global_indexing(node.index, constraints)
+
+    index = {k: v.subs(wave_subs) for k, v in index.items()}
+    return {key: IndexSequence(index[key].start, 1, 1) for key in index.keys()}
 
 
 def get_tensor_load_descriptor_config(
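For reference, the shape of the value the new get_shared_element_offset returns, sketched with hypothetical symbols (IndexSequence is reduced to a dataclass stand-in here; real keys are the kernel's IndexSymbols):

import sympy
from dataclasses import dataclass

@dataclass
class IndexSequence:  # stand-in for wave's IndexSequence(start, size, stride)
    start: sympy.Expr
    size: int
    stride: int

M, K, WAVE0, BLOCK_M = sympy.symbols("M K WAVE0 BLOCK_M")  # hypothetical

# one entry per dimension, keeping only the tile start; size and stride are
# pinned to 1, matching the return expression in the diff
shared_tile_index = {
    M: IndexSequence(BLOCK_M * WAVE0, 1, 1),
    K: IndexSequence(sympy.Integer(0), 1, 1),
}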
@@ -166,7 +168,6 @@ def get_tensor_load_descriptor_config(
     element_type: "DataType",
     wave_subs,
     hardware_constraint: "HardwareConstraint",
-    alloc_offset_map,
 ) -> TensorLoadConfig:
     """
     Get the tensor to shared config for the given read and write.
@@ -190,10 +191,10 @@ def get_tensor_load_descriptor_config(
 
     distributed_shape = materialize_shape(constraint_tile_size, symbolic_shape)
 
-    # get LDS byte offset
-    shared_tile_index = get_shared_tile_byte_offset(write, alloc_offset_map)
+    # get shared tile index
+    shared_tile_index = get_shared_element_offset(write, constraints, wave_subs)
 
-    # get global tile addr
+    # get global tile index
     global_tile_index = get_global_element_offset(read, wave_subs)
 
     return TensorLoadConfig(
@@ -269,13 +270,6 @@ def clear_padding(write: Write):
     custom_memory.update_arg("distributed_shape", tuple(new_distributed_shape))
 
 
-def get_allocation_offsets(trace) -> dict[fx.Node, int]:
-    allocs, _, alloc_info = get_alloc_info(trace)
-    offsets, _ = determine_allocations_offsets(alloc_info)
-    allocs_to_offsets = {allocs[i]: offsets[i] for i in range(len(allocs))}
-    return allocs_to_offsets
-
-
 def tensor_load_to_shared(
     trace: CapturedTrace,
     constraints: list[Constraint],
@@ -286,10 +280,9 @@
     1) option.use_global_to_shared is set
     2) target is gfx1250
     1. Build 1-many mapping of GLOBAL_READ: SHARED_WRITE_X ... #a
-    2. Get shared memory allocation information.
-    3. Build descriptors for tensor.load.to.lds.
-    4. Replace #a with tensor_load_to_shared op.
-    5. Update write dependencies.
+    2. Build descriptors for tensor.load.to.lds with proper IndexSequence offsets.
+    3. Replace #a with tensor_load_to_shared op.
+    4. Update write dependencies.
     """
     if not options.use_global_to_shared:
         return
@@ -337,8 +330,6 @@
     _, write = _writes[0]
     clear_padding(write)
 
-    allocate_offsets = get_allocation_offsets(trace)
-
     for reads_writes in id_to_read_write.values():
         read, write = reads_writes[0]
@@ -356,7 +347,6 @@
            element_type,
            wave_subs,
            hardware_constraint,
-           allocate_offsets,
         )
 
         if config is None:
