Commit 194ef24

[Wave] Basic global_load_lds support (#25)
Add some basic functional support for `global_load_lds` instruction generation. See doc for the details.

Currently doesn't support buffer ops (needs llvm/llvm-project#149407), masking (needs buffer ops and for Ivan to think real hard), or actual gathers (needs Ivan to think even harder).

Scheduling still doesn't work either, but I have fixed some immediate issues.

---------

Signed-off-by: nithinsubbiah <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
Co-authored-by: nithinsubbiah <[email protected]>
1 parent 46b1867 commit 194ef24

24 files changed, +843 -132 lines changed

docs/wave/gather_to_shared.rst

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+.. _gather_to_shared:
+
+Gather to Shared Memory Optimization
+====================================
+
+Overview
+--------
+
+The ``gather_to_shared`` pass enables direct memory loads from global memory to Local Data Store (LDS) without passing through registers, reducing data movement overhead.
+
+The underlying instruction is supported only on specific AMD GPU architectures (gfx94* and gfx95*).
+
+Architecture Support
+--------------------
+
+- **gfx94**: Supports 32-bit load/store widths
+- **gfx95**: Supports 32-bit, 96-bit, and 128-bit load/store widths
+
+Both architectures also support 8- and 16-bit load widths, but the values are zero- or sign-extended to 32 bits before the store, which is not very useful for us.
+
+Instruction Semantics
+---------------------
+
+``gather_to_shared`` is translated to the ``amdgpu.gather_to_lds`` MLIR op, which is lowered to ``global_load_lds_*`` instructions.
+
+Each thread reads 4, 12, or 16 bytes from arbitrary positions in global memory or a buffer and writes them contiguously to LDS, starting from the address specified by the first thread in the wave.
+Destination addresses in all other threads are ignored.
+
+The operation is asynchronous, and the AMDGPU backend currently doesn't enforce any dependencies with other LDS access operations (which may be fixed in the future). Users need to manually insert a ``waitcnt`` instruction to avoid data races.
+
+This is handled in the ``add_shared_memory_barriers`` pass. Currently, it inserts a ``waitcnt`` instruction right before the ``amdgpu.lds_barrier`` instruction if that barrier is preceded by any ``amdgpu.gather_to_lds`` instructions.
+
+
+Pass Description
+----------------
+
+The ``gather_to_shared`` pass works similarly to ``minimize_global_loads``: it takes the number of elements that need to be transferred and divides it by the total number of threads to determine the number of elements transferred per thread.
+
+Unlike ``minimize_global_loads``, it supports a very limited number of elements per thread and only supports a simple contiguous memory layout.
+
+Also, as LDS writes are always contiguous, it doesn't support padding if the number of elements per wave crosses a row boundary, and it will undo any LDS padding present in this case.
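
The per-thread split described in the doc can be illustrated with a small standalone sketch. It is not the pass itself: the function name, the `arch` keys, and the even-division assumption are hypothetical, and only the supported widths (32 bits on gfx94; 32, 96, and 128 bits on gfx95) come from the doc text.

```python
# Hypothetical sketch of the elements-per-thread computation described in
# gather_to_shared.rst. Names and structure are illustrative assumptions.

# Transfer widths (in bits) taken from the "Architecture Support" section.
SUPPORTED_WIDTHS_BITS = {
    "gfx94": {32},
    "gfx95": {32, 96, 128},
}


def elements_per_thread(total_elements, num_threads, element_bits, arch):
    """Return the per-thread element count, or None if the resulting
    transfer width is not supported on the given architecture family."""
    # Assumption: the work must split evenly across all threads.
    if total_elements % num_threads != 0:
        return None
    per_thread = total_elements // num_threads
    width_bits = per_thread * element_bits
    if width_bits not in SUPPORTED_WIDTHS_BITS[arch]:
        return None
    return per_thread


# A 2048-element f16 tile over 256 threads gives 8 elements (128 bits) per
# thread, which fits gfx95 but not gfx94.
print(elements_per_thread(2048, 256, 16, "gfx95"))
print(elements_per_thread(2048, 256, 16, "gfx94"))
```

Returning `None` here stands in for "leave the original read/write pair untouched"; how the real pass reacts to an unsupported width is not shown in this diff.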

docs/wave/wave.rst

Lines changed: 1 addition & 0 deletions
@@ -70,3 +70,4 @@ For more detailed information about Wave's architecture and optimization passes,
    schedule_modifier
    fused_softmax
    aplp
+   gather_to_shared

iree/turbine/kernel/lang/global_symbols.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def get_workgroup_symbol(i: int):
 WRITE_SHARED_DELAY = index_symbol("$WRITE_SHARED_DELAY")
 READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
 WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
+GLOBAL_TO_SHARED_DELAY = index_symbol("$GLOBAL_TO_SHARED_DELAY")
 MMA_DELAY = index_symbol("$MMA_DELAY")
 VALU_DELAY = index_symbol("$VALU_DELAY")
 SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY")

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 34 additions & 2 deletions
@@ -8,7 +8,6 @@
     TYPE_CHECKING,
     Any,
     Callable,
-    List,
     Optional,
     Sequence,
     Type,
@@ -74,7 +73,7 @@ def extract_slice(
 def set_wave_prio(priority: int): ...
 
 
-def shared_memory_barrier(): ...
+def shared_memory_barrier(wait_async_ops: bool = False): ...
 
 
 def workgroup_barrier(): ...
@@ -275,6 +274,18 @@ def select(
 ) -> "Register": ...
 
 
+def gather_to_lds(
+    src: Memory,
+    dst: Memory,
+    src_idx: dict[IndexSymbol, IndexSequence],
+    dst_idx: dict[IndexSymbol, IndexSequence],
+    dtype: DataType,
+    elements_per_thread: Optional[IndexExpr | int] = None,
+    src_mapping: Optional[IndexMapping] = None,
+    dst_mapping: Optional[IndexMapping] = None,
+): ...
+
+
 def define_op(op_name: str) -> Callable[[T], T]:
     def decorator(cls: T) -> T:
         cls.tkw_op_name = op_name
@@ -1218,6 +1229,8 @@ class SharedMemoryBarrier(CustomOp):
     Represents a shared memory barrier in the graph.
     """
 
+    wait_async_ops: bool = False
+
     @property
     def has_side_effects(self) -> bool:
         return True
@@ -2491,3 +2504,22 @@ def indexing_dims(self) -> list[IndexExpr]:
 
     def infer_type(self):
         self.type = get_custom(_to_sequence(self.args)[0]).type
+
+
+@define_op("gather_to_lds")
+@dataclass
+class GatherToLDS(CustomOp):
+    """
+    Represents an instruction that performs direct load from global
+    to lds. Source node points to the global memory to load from
+    and the destination node points to shared memory.
+    """
+
+    src: Memory
+    dst: Memory
+    src_idx: dict[IndexSymbol, IndexSequence]
+    dst_idx: dict[IndexSymbol, IndexSequence]
+    dtype: DataType
+    elements_per_thread: Optional[IndexExpr | int]
+    src_mapping: Optional[IndexMapping]
+    dst_mapping: Optional[IndexMapping]

iree/turbine/kernel/wave/barriers.py

Lines changed: 63 additions & 7 deletions
@@ -6,10 +6,68 @@
 
 from .utils.graph_utils import is_reduction_subgraph, is_barrier_between
 from .._support.tracing import CapturedTrace
-from ..ops.wave_ops import get_custom, Read, SharedMemoryBarrier, Write, NestedRegionOp
+from ..ops.wave_ops import (
+    AtomicOp,
+    CustomOp,
+    GatherToLDS,
+    NestedRegionOp,
+    Read,
+    SharedMemoryBarrier,
+    Write,
+    get_custom,
+)
 from ..lang.global_symbols import SHARED_ADDRESS_SPACE
 import torch.fx as fx
 from typing import Optional
+from enum import Enum, auto
+
+
+class MemoryAccessType(Enum):
+    """Enum to classify memory access operations."""
+
+    NONE = auto()
+    READ = auto()
+    WRITE = auto()
+    READ_WRITE = auto()
+
+
+def is_shared_memory_op(node: CustomOp) -> bool:
+    if isinstance(node, (Read, Write, AtomicOp)):
+        return node.memory_type.address_space == SHARED_ADDRESS_SPACE
+    elif isinstance(node, GatherToLDS):
+        return True
+
+    return False
+
+
+def get_memory_access_type(node: CustomOp) -> MemoryAccessType:
+    if isinstance(node, Read):
+        return MemoryAccessType.READ
+    elif isinstance(node, Write):
+        return MemoryAccessType.WRITE
+    elif isinstance(node, AtomicOp):
+        return MemoryAccessType.READ_WRITE
+    elif isinstance(node, GatherToLDS):
+        return MemoryAccessType.WRITE
+    else:
+        return MemoryAccessType.NONE
+
+
+def need_barrier(node1: CustomOp, node2: CustomOp) -> bool:
+    access_type1 = get_memory_access_type(node1)
+    if access_type1 == MemoryAccessType.NONE:
+        return False
+    access_type2 = get_memory_access_type(node2)
+    if access_type2 == MemoryAccessType.NONE:
+        return False
+
+    if access_type1 != access_type2:
+        return True
+
+    if access_type1 == MemoryAccessType.READ_WRITE:
+        return True
+
+    return False
 
 
 def add_shared_memory_barriers(
@@ -32,19 +90,17 @@ def add_shared_memory_barriers(
 
     for node in graph.nodes:
         custom = get_custom(node)
-        if (
-            isinstance(custom, (Read, Write))
-            and custom.memory_type.address_space == SHARED_ADDRESS_SPACE
-        ):
+        if is_shared_memory_op(custom):
             if last_node is None:
                 last_node = custom
                 continue
-            if type(custom) != type(last_node) and not is_barrier_between(
+            if need_barrier(custom, last_node) and not is_barrier_between(
                 last_node.fx_node, custom.fx_node
             ):
+                is_async = isinstance(last_node, GatherToLDS)
                 # Synchronize after the write to shared memory before we read from it.
                 with graph.inserting_before(node):
-                    SharedMemoryBarrier().add_to_graph(graph)
+                    SharedMemoryBarrier(wait_async_ops=is_async).add_to_graph(graph)
             last_node = custom
         if isinstance(custom, NestedRegionOp):
             last_node = add_shared_memory_barriers(
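
To make the barrier rule in this hunk easier to follow, here is a simplified standalone sketch. It mirrors the `need_barrier` logic on a flat list of access types and ignores nested regions and already-present barriers; the `Access` enum and `barrier_positions` helper are hypothetical names introduced only for illustration.

```python
from enum import Enum, auto


class Access(Enum):
    """Stand-in for MemoryAccessType from the diff above."""

    NONE = auto()
    READ = auto()
    WRITE = auto()  # in the diff, both Write and GatherToLDS map to WRITE
    READ_WRITE = auto()


def need_barrier(a: Access, b: Access) -> bool:
    # No barrier if either op does not touch shared memory.
    if Access.NONE in (a, b):
        return False
    # Differing access kinds conflict; two atomics (READ_WRITE) also conflict.
    return a != b or a == Access.READ_WRITE


def barrier_positions(accesses):
    """Return the indices before which a barrier would be inserted."""
    positions = []
    last = None
    for i, acc in enumerate(accesses):
        if acc == Access.NONE:
            continue  # not a shared memory op, skipped like in the pass
        if last is not None and need_barrier(acc, last):
            positions.append(i)
        last = acc
    return positions


ops = [Access.WRITE, Access.WRITE, Access.READ, Access.READ, Access.WRITE]
print(barrier_positions(ops))  # barriers before index 2 and index 4
```

In the real pass, the barrier inserted after a `GatherToLDS` additionally carries `wait_async_ops=True`, which this sketch does not model.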

iree/turbine/kernel/wave/cache.py

Lines changed: 1 addition & 0 deletions
@@ -205,6 +205,7 @@ def get_hash(
         options.use_buffer_store_ops,
         options.use_stride_cache_swizzle,
         options.use_fast_math,
+        options.use_global_to_shared,
         options.minimize_shared_allocs,
         options.reorder_allocs,
         options.override_schedule,
options.override_schedule,

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 27 additions & 1 deletion
@@ -30,12 +30,12 @@
     amdgpu_d,
     arith_d,
     gpu_d,
+    llvm_d,
     math_d,
     memref_d,
     rocdl_d,
     scf_d,
     vector_d,
-    llvm_d,
 )
 from iree.turbine.aot.support.ir_utils import (
     _is_float_type,
@@ -1398,8 +1398,34 @@ def handle_set_wave_prio(emitter: WaveEmitter, node: fx.Node):
     rocdl_d.s_setprio(prio)
 
 
+def waitcnt(vmcnt: int):
+    """
+    Create `s_waitcnt` with the specified vmcnt and all other counters set to max.
+    """
+
+    # Clamp vmcnt to 6bits; a lower vmcnt will produce a conservative wait
+    vmCnt = min(63, vmcnt)
+
+    # Extract low and high bits and combine while setting all other bits to 1
+    lowBits = vmCnt & 0xF
+    highBits = (vmCnt >> 4) << 14
+    otherCnts = ~0xC00F  # C00F has bits 15:14 and 3:0 set
+    waitValue = lowBits | highBits | otherCnts
+    waitValue &= 0xFFFF
+
+    rocdl_d.s_waitcnt(waitValue)
+
+
 @handle_op(shared_memory_barrier)
 def handle_shared_memory_barrier(emitter: WaveEmitter, node: fx.Node):
+    try:
+        (wait_async_ops,) = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    if wait_async_ops:
+        waitcnt(0)
+
     amdgpu_d.lds_barrier()
 
 
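
The bit packing in `waitcnt` can be checked in isolation. The sketch below repeats the same arithmetic as the diff but returns the 16-bit immediate instead of emitting `rocdl_d.s_waitcnt`; the function name `encode_vmcnt_wait` is a hypothetical one chosen for this example.

```python
def encode_vmcnt_wait(vmcnt: int) -> int:
    """Pack vmcnt into an s_waitcnt immediate the same way the diff does:
    vmcnt[3:0] goes to bits 3:0, vmcnt[5:4] goes to bits 15:14, and every
    other counter field is set to all ones (its maximum, i.e. no wait)."""
    vm_cnt = min(63, vmcnt)  # clamp to 6 bits
    low_bits = vm_cnt & 0xF
    high_bits = (vm_cnt >> 4) << 14
    other_cnts = ~0xC00F  # everything except the two vmcnt fields
    return (low_bits | high_bits | other_cnts) & 0xFFFF


# vmcnt=0 is what handle_shared_memory_barrier requests: wait for all
# outstanding vector memory operations.
print(hex(encode_vmcnt_wait(0)))   # 0x3ff0
print(hex(encode_vmcnt_wait(17)))  # 0x7ff1
print(hex(encode_vmcnt_wait(63)))  # 0xffff
```

Values above 63 clamp to 63, so they encode the same immediate as the maximum count.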

iree/turbine/kernel/wave/codegen/read_write.py

Lines changed: 91 additions & 2 deletions
@@ -4,6 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+import copy
 import sympy
 import functools
 from typing import Any, Optional, Dict
@@ -34,12 +35,20 @@
 from ...compiler.vector_codegen import (
     cast_kernel_buffer,
     cast_py_literal,
+    cast_py_value,
     cast_vector,
 )
 
-from ...ops.wave_ops import get_custom, read, write, CustomOp
+from ...ops.wave_ops import (
+    CustomOp,
+    gather_to_lds,
+    get_custom,
+    read,
+    write,
+)
 
 from ..utils.general_utils import get_fastest_index, infer_dim
+from ..utils.mapping_utils import transform_index_on_mapping
 from ..utils.symbol_utils import safe_subs, subs_idxc
 
 from ..._support.indexing import IndexingContext, IndexExpr, IndexSequence, IndexSymbol
@@ -48,10 +57,11 @@
 
 from .emitter import (
     WaveEmitter,
-    handle_op,
     add_emitter_subs,
     gen_sympy_index,
     get_constant_attr,
+    get_type_or_element_type,
+    handle_op,
 )
 
 
@@ -883,3 +893,82 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
         mask,
         offsets_vec,
     )
+
+
+@handle_op(gather_to_lds)
+def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
+    try:
+        (
+            src,
+            dst,
+            src_idx,
+            dst_idx,
+            element_type,
+            elements_per_thread,
+            src_mapping,
+            dst_mapping,
+        ) = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    element_type = IrType.parse(element_type.dtype.ir_type_asm())
+
+    src_symbolic_shape = _get_symbolic_shape(src)
+    dst_symbolic_shape = _get_symbolic_shape(dst)
+
+    src = cast_py_value(emitter, src)
+    dst = cast_py_value(emitter, dst)
+    src_data_type = get_type_or_element_type(src.ir_value.type)
+    dst_data_type = get_type_or_element_type(dst.ir_value.type)
+
+    if not (
+        MemRefType.isinstance(src.ir_value.type)
+        and MemRefType.isinstance(dst.ir_value.type)
+    ):
+        op = get_custom(node)
+        raise ValidationError(
+            f"Expected src and dst to be of Memref type for\n"
+            f"{op}\nGot\n"
+            f"src: {src.ir_value.type}\n"
+            f"dst: {dst.ir_value.type}\n"
+        )
+
+    if src_data_type != dst_data_type:
+        op = get_custom(node)
+        raise ValidationError(
+            f"Expected src and dst to have same data type for\n"
+            f"{op}\nGot\n"
+            f"src: {src_data_type} vs dst: {dst_data_type}\n"
+        )
+
+    src = src.ir_value
+    dst = dst.ir_value
+
+    if src_mapping:
+        src_idx = transform_index_on_mapping(src_mapping, src_symbolic_shape, src_idx)
+    if dst_mapping:
+        dst_idx = transform_index_on_mapping(dst_mapping, dst_symbolic_shape, dst_idx)
+
+    store_type = VectorType.get((elements_per_thread,), element_type)
+
+    src_index, src_index_wg, src_index_th = _build_start_indices(emitter, src_idx)
+    dst_index, _, _ = _build_start_indices(emitter, dst_idx)
+
+    if False:  # TODO: Buffer stuff needs upstream fixes
+        strides = strides_from_symbolic_shape(
+            IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True
+        )
+        strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
+
+        src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides)
+        src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter)
+
+        src_index = [offset_th]
+
+    amdgpu_d.gather_to_lds(
+        src=src,
+        src_indices=src_index,
+        dst=dst,
+        dst_indices=dst_index,
+        transfer_type=store_type,
+    )
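
The op emitted by this handler follows the semantics described in the new doc page: every lane reads from its own source position, but only the first lane's destination address is used, and lanes write back-to-back from there. The following is a plain-Python model of that behavior under those stated semantics, not of the hardware or the MLIR op; `simulate_gather_to_lds` and its parameters are hypothetical names.

```python
def simulate_gather_to_lds(global_mem, lds, src_offsets, dst_offsets, elems_per_lane):
    """Model one wave-wide gather: lane i copies elems_per_lane elements from
    global_mem[src_offsets[i]:] into a contiguous LDS region."""
    # Only the first lane's destination address matters; the rest are ignored.
    base = dst_offsets[0]
    for lane, src in enumerate(src_offsets):
        start = base + lane * elems_per_lane
        lds[start : start + elems_per_lane] = global_mem[src : src + elems_per_lane]
    return lds


global_mem = list(range(100))
lds = [0] * 8
# Four lanes, two elements each; lanes 1..3 carry junk destinations on purpose.
result = simulate_gather_to_lds(
    global_mem, lds, src_offsets=[10, 50, 30, 70], dst_offsets=[0, 99, 99, 99],
    elems_per_lane=2,
)
print(result)  # [10, 11, 50, 51, 30, 31, 70, 71]
```

The model also shows why LDS padding cannot be preserved when a wave's data crosses a row boundary: the destination layout is fixed by lane order alone.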
