 ** Global offset (Index Sequence unit: elements)
 global offset is calculated by preserving only the global index sequence together with the wave index sequence.
 this is valid because the tensor load instruction expects the global base address of a tile.
-** Shared offset (Allocation size unit: bytes)
-shared offset is calculated by materializing the distributed shape from a "write_shared" node.
+** Shared offset (Index Sequence unit: elements)
+shared offset is calculated by preserving the index sequence from a "write_shared" node,
+removing thread offsets within a tile, similar to the global offset.
 
 Example:
 For loading tensors with shape M x K to alloc0 (smem), and N x K to alloc1 (smem),
 with tile sizes BLOCK_M x BLOCK_K and BLOCK_N x BLOCK_K, where K is the contiguous dimension:
-- global offset perserves BLOCK index and WAVE_ID: $WG0 * BLOCK_M + BLOCK_M * ($T0 // 32)
-- shared offset:
-  - for alloc0 = 0
-  - for alloc1 = BLOCKM * BLOCK_K
+- global offset preserves the BLOCK index and WAVE_ID: $WG0 * BLOCK_M + BLOCK_M * ($T0 // 32)
+- shared offset preserves the tile-level index: similar structure to the global offset
 """
 
 import logging
@@ -77,12 +76,10 @@
     get_hardware_constraint,
     infer_dim,
     is_pow2,
+    remove_global_indexing,
 )
 from .utils.symbol_utils import subs_idxc
 
-from .memory_analysis.minimize_shared_allocs import get_alloc_info
-from .memory_analysis.solver import determine_allocations_offsets
-
 logger = logging.getLogger(__name__)
 
 
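To make the docstring's offset formula concrete, here is a standalone sketch (not part of this change) that evaluates the global offset expression with sympy. The symbol names mirror the docstring ($WG0, $T0, BLOCK_M), the wave size of 32 is read off the `$T0 // 32` term, and the sample numbers are invented.

```python
import sympy

# Symbols mirroring the docstring; the sample values below are invented.
WG0, T0, BLOCK_M = sympy.symbols("WG0 T0 BLOCK_M", integer=True, nonnegative=True)

# Tile-level global offset along M: the workgroup's tile plus the wave's sub-tile.
# No per-lane term appears, since the tensor load expects a tile base address.
global_offset_m = WG0 * BLOCK_M + BLOCK_M * sympy.floor(T0 / 32)

# Workgroup 2, first lane of wave 1 (T0 = 32), BLOCK_M = 128:
print(global_offset_m.subs({WG0: 2, T0: 32, BLOCK_M: 128}))  # -> 384
```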
@@ -115,7 +112,7 @@ class TensorLoadConfig:
     """
     element_type
     tensor_tile_shapes: [tile dim 0 shape, tile dim 1 shape]
-    shared_tile_index (bytes)
+    shared_tile_index (IndexSequence)
     global_tile_index (IndexSequence)
     bounds
 
@@ -124,7 +121,7 @@ class TensorLoadConfig:
 
     element_type: "DataType"
     distributed_shape: list[IndexExpr]
-    shared_tile_index: int
+    shared_tile_index: dict[IndexSymbol, IndexSequence]
     global_tile_index: dict[IndexSymbol, IndexSequence]
     bounds: dict[IndexSymbol, IndexExpr]
 
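For quick reference, a sketch of roughly what the dataclass looks like after this hunk; the type aliases below are placeholders for the real IndexSymbol, IndexSequence, IndexExpr, and DataType defined elsewhere in the codebase.

```python
from dataclasses import dataclass
from typing import Any

# Placeholder aliases; the real types come from the surrounding codebase.
IndexSymbol = Any
IndexSequence = Any
IndexExpr = Any
DataType = Any

@dataclass
class TensorLoadConfigSketch:
    element_type: DataType
    distributed_shape: list[IndexExpr]
    # Both offsets are now per-dimension index sequences (elements),
    # rather than a flat byte offset into the shared allocation.
    shared_tile_index: dict[IndexSymbol, IndexSequence]
    global_tile_index: dict[IndexSymbol, IndexSequence]
    bounds: dict[IndexSymbol, IndexExpr]
```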
@@ -149,13 +146,18 @@ def get_global_element_offset(
     return {key: IndexSequence(index[key].start, 1, 1) for key in index.keys()}
 
 
-def get_shared_tile_byte_offset(node: fx.Node, alloc_offset_map) -> int:
+def get_shared_element_offset(
+    node: CustomOp, constraints: list[Constraint], wave_subs
+) -> dict[IndexSymbol, IndexSequence]:
     """
-    LDS address = Shared mem buffer + tile offset in bytes
-    This function returns the tile offset.
+    Shared memory address = shared mem buffer + tile offset.
+    This function returns the tile index by removing thread offsets within a tile.
     """
-    offset_sym = alloc_offset_map[node.memory]
-    return int(offset_sym)
+    assert isinstance(node, Write), "Expected a Write custom node as the caller argument"
+    index = remove_global_indexing(node.index, constraints)
+
+    index = {k: v.subs(wave_subs) for k, v in index.items()}
+    return {key: IndexSequence(index[key].start, 1, 1) for key in index.keys()}
 
 
 def get_tensor_load_descriptor_config(
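A toy analogue (assumptions only, with made-up symbols WG0, WAVE0, LANE, BLOCK_M standing in for the real index symbols) of what the new get_shared_element_offset does to a write's start expression: the workgroup-level term is stripped, standing in for remove_global_indexing, and the lane-within-wave term is substituted away, standing in for the wave_subs step, so only the wave's tile-level start inside the shared buffer survives.

```python
import sympy

# Made-up symbols standing in for the real index symbols.
WG0, WAVE0, LANE, BLOCK_M = sympy.symbols("WG0 WAVE0 LANE BLOCK_M", integer=True)

# Hypothetical start expression of a shared write along M, split into
# workgroup, wave, and lane contributions for clarity.
write_start_m = WG0 * BLOCK_M + WAVE0 * BLOCK_M + LANE

shared_start_m = write_start_m.subs(WG0, 0)    # remove_global_indexing analogue
shared_start_m = shared_start_m.subs(LANE, 0)  # wave_subs analogue
print(shared_start_m)  # -> BLOCK_M*WAVE0, the tile-level shared offset
```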
@@ -166,7 +168,6 @@ def get_tensor_load_descriptor_config(
     element_type: "DataType",
     wave_subs,
     hardware_constraint: "HardwareConstraint",
-    alloc_offset_map,
 ) -> TensorLoadConfig:
     """
     Get the tensor to shared config for the given read and write.
@@ -190,10 +191,10 @@ def get_tensor_load_descriptor_config(
 
     distributed_shape = materialize_shape(constraint_tile_size, symbolic_shape)
 
-    # get LDS byte offset
-    shared_tile_index = get_shared_tile_byte_offset(write, alloc_offset_map)
+    # get shared tile index
+    shared_tile_index = get_shared_element_offset(write, constraints, wave_subs)
 
-    # get global tile addr
+    # get global tile index
     global_tile_index = get_global_element_offset(read, wave_subs)
 
     return TensorLoadConfig(
@@ -269,13 +270,6 @@ def clear_padding(write: Write):
     custom_memory.update_arg("distributed_shape", tuple(new_distributed_shape))
 
 
-def get_allocation_offsets(trace) -> dict[fx.Node, int]:
-    allocs, _, alloc_info = get_alloc_info(trace)
-    offsets, _ = determine_allocations_offsets(alloc_info)
-    allocs_to_offsets = {allocs[i]: offsets[i] for i in range(len(allocs))}
-    return allocs_to_offsets
-
-
 def tensor_load_to_shared(
     trace: CapturedTrace,
     constraints: list[Constraint],
@@ -286,10 +280,9 @@ def tensor_load_to_shared(
     1) option.use_global_to_shared is set
     2) target is gfx1250
     1. Build 1-many mapping of GLOBAL_READ: SHARED_WRITE_X ... #a
-    2. Get shared memory allocation information.
-    3. Build descriptors for tensor.load.to.lds.
-    4. Replace #a with tensor_load_to_shared op.
-    5. Update write dependencies.
+    2. Build descriptors for tensor.load.to.lds with proper IndexSequence offsets.
+    3. Replace #a with tensor_load_to_shared op.
+    4. Update write dependencies.
     """
     if not options.use_global_to_shared:
         return
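To illustrate step 1 of the docstring above, a small standalone sketch (assumed data shapes, not the real fx-graph traversal) of grouping every shared write fed by the same global read, so each group can later be replaced by a single tensor_load_to_shared op.

```python
from collections import defaultdict

def group_reads_to_writes(pairs):
    """pairs: iterable of (global_read, shared_write) node pairs."""
    read_to_writes = defaultdict(list)
    for global_read, shared_write in pairs:
        read_to_writes[global_read].append(shared_write)
    return dict(read_to_writes)

# Placeholder node names; the real pass walks the captured trace.
print(group_reads_to_writes([
    ("read_a", "write_a0"),
    ("read_a", "write_a1"),
    ("read_b", "write_b0"),
]))
# -> {'read_a': ['write_a0', 'write_a1'], 'read_b': ['write_b0']}
```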
@@ -337,8 +330,6 @@ def tensor_load_to_shared(
         _, write = _writes[0]
         clear_padding(write)
 
-    allocate_offsets = get_allocation_offsets(trace)
-
     for reads_writes in id_to_read_write.values():
         read, write = reads_writes[0]
 
@@ -356,7 +347,6 @@ def tensor_load_to_shared(
             element_type,
             wave_subs,
             hardware_constraint,
-            allocate_offsets,
         )
 
         if config is None: