
Commit 8141d52

[TKW] Propagate index from reduce nodes (#644)
* If the kernel doesn't have `mma`s but does have reduce ops, use them to determine the indexing of the other ops.
* The `tkw.WorkgroupConstraint(N, N, 0)` hack is no longer needed, as reduce nodes inside the reduction loop are now automatically distributed across threads.
* Fix `test_toy_online_softmax` for unaligned shapes.
* Refactor constraints and introduce a `DistributionConstraint` base class.

---------

Signed-off-by: Ivan Butygin <[email protected]>
1 parent b37040c commit 8141d52
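For context, here is a minimal sketch of the kind of kernel this change targets, loosely adapted from the `get_reduce_mapping` docstring example added in this commit. The imports, symbol setup, hardware-constraint values, and the `row_max`/`a`/`c` names are illustrative assumptions, not part of the diff:

```python
# Sketch (assumed setup): an online-max kernel with no mma ops. After this
# commit, the index of tkw.max seeds thread-dependent indexing for the rest
# of the graph, so no dummy WorkgroupConstraint on N is needed.
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw

M, N = tkl.sym.M, tkl.sym.N
BLOCK_M, BLOCK_N = tkl.sym.BLOCK_M, tkl.sym.BLOCK_N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints = [
    tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 1, N: BLOCK_N}),
    tkw.WorkgroupConstraint(M, BLOCK_M, 1),
    tkw.TilingConstraint(N, BLOCK_N),
]

@tkw.wave(constraints)
def row_max(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32],
):
    init_max = tkl.Register[M, tkl.f32](-1e6)

    @tkw.reduction(N, init_args=[init_max])
    def repeat(partial_max: tkl.Register[M, tkl.f32]) -> tkl.Register[M, tkl.f32]:
        res = tkw.read(a)  # [M, N] tile
        # Reduce over N; its index distributes N across wave-dim-0 threads.
        return tkw.max(res, partial_max, dim=N)

    tkw.write(repeat, c, elements_per_thread=1)
```

Previously a kernel like this needed a degenerate `tkw.WorkgroupConstraint(N, N, 0)` to get N distributed at all; now the reduce op's index seeds the propagation.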

File tree: 6 files changed, +273 −58 lines

iree/turbine/kernel/wave/analysis/index_sequence_analysis.py

Lines changed: 145 additions & 4 deletions
@@ -15,6 +15,7 @@
     Output,
     Placeholder,
     Read,
+    ReduceOp,
     Reduction,
     Write,
     get_custom,
@@ -33,6 +34,7 @@
 from ..utils.general_utils import (
     get_hardware_constraint,
     get_largest_index_and_size,
+    get_workgroup_constraints,
     partial,
 )
 from ..utils.mma_utils import (
@@ -145,12 +147,21 @@ def set_node_indices(
         print_trace(trace)

     graph_passes = []
-    if mma_mapping != {}:
+    if mma_mapping:
         graph_passes += [
             partial(
                 set_thread_dependent_index_from_mma, constraints, mma_mapping, trace
             )
         ]
+    elif reduce_mapping := get_reduce_mapping(trace, constraints):
+        graph_passes += [
+            partial(
+                set_thread_dependent_index_from_reduce,
+                constraints,
+                trace,
+                reduce_mapping,
+            )
+        ]
     else:
         graph_passes += [
             partial(set_thread_dependent_index_from_read_write, constraints, trace)
@@ -516,9 +527,7 @@ def set_thread_dependent_index_from_read_write(
     assert sources, "No read nodes found in the graph."

     visited = set()
-    workgroup_constraints = [
-        c for c in constraints if isinstance(c, WorkgroupConstraint)
-    ]
+    workgroup_constraints = get_workgroup_constraints(constraints)
     symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)]
     for source in sources:
         visited = visited.union(set([x for x in sources]))
@@ -533,6 +542,138 @@
         )


+def get_reduce_mapping(
+    trace: CapturedTrace, constraints: list[Constraint]
+) -> dict[ReduceOp, dict[IndexSymbol, IndexSequence]]:
+    """
+    Get the mapping of the reduce ops to the index sequence.
+
+    The resulting index will have the reduction dim distributed across wg0
+    threads and the rest of the dims distributed like read/write nodes,
+    according to the WorkgroupConstraints.
+
+    Example:
+    ```
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    ...
+    @tkw.reduction(N, init_args=[init_max, init_sum])
+    def repeat(
+        partial_max: tkl.Register[M, tkl.f32],
+    ) -> tkl.Register[M, tkl.f32]:
+        res = tkw.read(a)  # [M, N]
+        partial_max = tkw.max(res, partial_max, dim=N)  # {N: 2*$T0 : 2 : 1, M: $T1 : 1 : 1}
+        ...
+    ```
+
+    """
+    sources = trace.walk(lambda node: isinstance(get_custom(node), ReduceOp))
+    hardware_constraint = get_hardware_constraint(constraints)
+    workgroup_constraints = get_workgroup_constraints(constraints)
+
+    reduce_mapping = {}
+    for source in sources:
+        custom = get_custom(source)
+        index = {}
+
+        dim = custom.dim
+
+        # Compute the index sequence for the reduction dimension based on the
+        # threads per wave and the vector size.
+        threads_per_wave = hardware_constraint.threads_per_wave
+        vector_size = hardware_constraint.vector_shapes[dim]
+        assert (
+            vector_size % threads_per_wave == 0
+        ), f"Vector size {dim}={vector_size} must be divisible by threads per wave {threads_per_wave}"
+        elements_per_thread = vector_size // threads_per_wave
+        stride = compute_stride(
+            custom.indexing_dims, hardware_constraint.vector_shapes, dim
+        )
+        index[dim] = hardware_constraint.apply_read_write_thread_mapping(
+            dim, 0, elements_per_thread, stride
+        )
+
+        for dim in custom.indexing_dims:
+            elements_per_thread = 1
+            stride = compute_stride(
+                custom.indexing_dims, hardware_constraint.vector_shapes, dim
+            )
+            wg_constraint = [x for x in workgroup_constraints if x.dim == dim]
+            assert (
+                len(wg_constraint) <= 1
+            ), f"Multiple workgroup constraints for dimension {dim}"
+            if wg_constraint:
+                workgroup_dim = wg_constraint[0].workgroup_dim
+            else:
+                continue

+            index[dim] = hardware_constraint.apply_read_write_thread_mapping(
+                dim, workgroup_dim, elements_per_thread, stride
+            )
+
+        reduce_mapping[custom] = index
+
+    return reduce_mapping
+
+
+def populate_reduce_source_indices(
+    node: ReduceOp,
+    hardware_constraint: HardwareConstraint,
+    workgroup_constraints: list[WorkgroupConstraint],
+    index: dict[IndexSymbol, IndexSequence],
+):
+    """
+    Populate the source indices for the reduce op.
+    """
+    vector_shapes = hardware_constraint.vector_shapes
+    ret = []
+    if isinstance(node.arg, Sequence):
+        ret += [(get_custom(a), index, vector_shapes) for a in node.arg]
+    else:
+        ret += [(get_custom(node.arg), index, vector_shapes)]
+
+    # Reduce args must contain an index for the reduction dimension,
+    # but the init args and the reduction result itself do not.
+    res_index = copy(index)
+    del res_index[node.dim]
+
+    if node.init:
+        ret += [(get_custom(node.init), res_index, vector_shapes)]
+
+    ret += [(node, res_index, vector_shapes)]
+
+    return ret
+
+
+def set_thread_dependent_index_from_reduce(
+    constraints: Sequence[Constraint],
+    trace: CapturedTrace,
+    reduce_mapping: dict[ReduceOp, dict[IndexSymbol, IndexSequence]],
+):
+    """
+    Set the thread dependent index, rooting on reduce ops.
+    """
+    hardware_constraint = get_hardware_constraint(constraints)
+    sources = trace.walk(lambda node: isinstance(get_custom(node), ReduceOp))
+    sources = [get_custom(x) for x in sources]
+    assert sources, "No reduce nodes found in the graph."
+
+    visited = set()
+    workgroup_constraints = get_workgroup_constraints(constraints)
+    symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)]
+    for source in sources:
+        visited = visited.union(set([x for x in sources]))
+        visited.remove(source)
+        index = reduce_mapping[source]
+        new_sources = populate_reduce_source_indices(
+            source, hardware_constraint, workgroup_constraints, index
+        )
+        visited = propagate_indices(
+            new_sources,
+            visited,
+            symbolic_constraints,
+        )
+
+
 def set_post_expansion_indices(trace: CapturedTrace, constraints: list[Constraint]):
     """
     Add offsets to the indices based on the expanded dims.
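To make the index arithmetic in `get_reduce_mapping` concrete: for the reduction dim, each wave-dim-0 thread owns `vector_size // threads_per_wave` elements. The sketch below is a plain-Python model of that split, mirroring the removed `compute_access_pattern_using_vector_shapes` helper (whose `IndexSequence` was `thread_id * elements_per_thread, elements_per_thread, stride`); the concrete sizes are assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass
class IndexSeq:
    start: int   # first element this thread touches
    size: int    # number of elements per thread
    stride: int  # step between consecutive elements

def reduction_dim_index(
    thread_id: int, vector_size: int, threads_per_wave: int, stride: int = 1
) -> IndexSeq:
    # Mirrors get_reduce_mapping: split the vector along the reduction dim
    # so each wave-dim-0 thread owns a contiguous chunk.
    assert vector_size % threads_per_wave == 0
    elements_per_thread = vector_size // threads_per_wave
    return IndexSeq(thread_id * elements_per_thread, elements_per_thread, stride)

# With 64 threads per wave and a vector size of 128 along N, each thread
# gets 2 elements and thread t starts at 2*t, matching the
# `N: 2*$T0 : 2 : 1` sequence in the docstring example above.
print(reduction_dim_index(0, 128, 64))  # IndexSeq(start=0, size=2, stride=1)
print(reduction_dim_index(1, 128, 64))  # IndexSeq(start=2, size=2, stride=1)
```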

iree/turbine/kernel/wave/constraints.py

Lines changed: 28 additions & 14 deletions
@@ -82,6 +82,24 @@ def apply(self) -> IndexSequence:
         ...


+@dataclass
+class DistributionConstraint(Constraint):
+    """
+    Base class for constraints that distribute a dimension across a
+    workgroup or reduction loop.
+    """
+
+    @property
+    def work_bound(self) -> IndexExpr:
+        """
+        Returns the work bound for the constraint.
+
+        It may be different from the dimension of the tensor if the
+        dimension is not divisible by the tile size.
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+
 @dataclass
 class HardwareConstraint(Constraint):
     """
@@ -263,18 +281,6 @@ def subs_vector_shapes(self, index_map: dict[IndexSymbol, int]):
             if isinstance(vector_size, IndexExpr):
                 self.vector_shapes[vector_dim] = vector_size.subs(index_map)

-    def compute_access_pattern_using_vector_shapes(
-        self,
-        dim: IndexSymbol,
-        workgroup_dim: int,
-        elements_per_thread: int | IndexSymbol,
-        stride: int,
-    ) -> IndexSequence:
-        thread_id = self.get_thread_id_from_workgroup_dim(workgroup_dim)
-        return IndexSequence(
-            thread_id * elements_per_thread, elements_per_thread, stride
-        )
-
     def apply(self):
         assert False, "Call either apply_read_write_thread_mapping or apply_mma_mapping"

@@ -370,7 +376,7 @@ def apply_mma_mapping(


 @dataclass
-class WorkgroupConstraint(Constraint):
+class WorkgroupConstraint(DistributionConstraint):
     """
     A constraint of the form `tkw.WorkgroupConstraint(M, BLOCK_M, 0)`
     specifies that we want to distribute dimension M along workgroup dim 0
@@ -410,6 +416,10 @@ def apply(self) -> IndexSequence:
             return IndexSequence(self.apply_fn(self.wg_dim), 1)
         return IndexSequence(self.wg_dim * self.tile_size, 1)

+    @property
+    def work_bound(self) -> IndexExpr:
+        return self.count * self.tile_size
+

 def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
     sorted_constraints = sorted(
@@ -428,7 +438,7 @@ def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]


 @dataclass
-class TilingConstraint(Constraint):
+class TilingConstraint(DistributionConstraint):
     """
     A constraint of the form `tkw.TilingConstraint(K, BLOCK_K)` specifies
     that we want to tile the K dimension with a tile size of BLOCK_K. This
@@ -469,6 +479,10 @@ def apply(self) -> IndexSequence:
         )
         return IndexSequence(self.start + self.induction_var * self.tile_size, 1)

+    @property
+    def work_bound(self) -> IndexExpr:
+        return self.start + self.count * self.tile_size
+

 @dataclass
 class WaveConstraint(Constraint):
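A self-contained sketch of the `work_bound` contract introduced here, with plain integers standing in for the real `IndexExpr` symbols (the class names are reused for illustration only): each subclass reports how much work its distribution covers, which can overshoot the dim when shapes are unaligned.

```python
from dataclasses import dataclass

@dataclass
class DistributionConstraint:
    dim: str
    tile_size: int
    count: int

    @property
    def work_bound(self) -> int:
        raise NotImplementedError("Subclasses must implement this method")

@dataclass
class WorkgroupConstraint(DistributionConstraint):
    @property
    def work_bound(self) -> int:
        # Total work covered by all workgroups along this dim.
        return self.count * self.tile_size

@dataclass
class TilingConstraint(DistributionConstraint):
    start: int = 0

    @property
    def work_bound(self) -> int:
        # Tiling may begin at a nonzero offset, so the bound includes it.
        return self.start + self.count * self.tile_size

# E.g. M = 100 tiled by BLOCK_M = 32 needs 4 workgroups, so the bound (128)
# overshoots the dim (100) and partial reads/writes must be bounds-checked.
wg = WorkgroupConstraint("M", tile_size=32, count=4)
assert wg.work_bound == 128
```

Callers can now ask any `DistributionConstraint` for its bound instead of special-casing `count * tile_size` per class, which is exactly what the `general_utils.py` changes below rely on.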

iree/turbine/kernel/wave/utils/general_utils.py

Lines changed: 6 additions & 5 deletions
@@ -18,6 +18,7 @@
 from ..assumptions import Assumption
 from ..constraints import (
     Constraint,
+    DistributionConstraint,
     HardwareConstraint,
     TilingConstraint,
     WorkgroupConstraint,
@@ -144,10 +145,10 @@ def align_index_vars(
     need partial reads/writes.
     """
     key_subs = {
-        c.dim: (c.count * c.tile_size)
+        c.dim: (c.work_bound)
         for c in constraints
-        if isinstance(c, (TilingConstraint, WorkgroupConstraint))
-        and subs_idxc(c.dim) != subs_idxc(c.count * c.tile_size)
+        if isinstance(c, DistributionConstraint)
+        and subs_idxc(c.dim) != subs_idxc(c.work_bound)
     }
     return {safe_subs(key, key_subs): index[key] for key in index}

@@ -157,14 +158,14 @@ def find_index_bounds(
 ) -> Optional[list[IndexExpr]]:
     bounds = []
     for constraint in constraints:
-        if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)):
+        if not isinstance(constraint, DistributionConstraint):
             continue

         dim = constraint.dim
         if dim not in index:
             continue

-        work_size = constraint.count * constraint.tile_size
+        work_size = constraint.work_bound
         if subs_idxc(work_size) == subs_idxc(dim):
             continue
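A toy model of the bound check in `find_index_bounds` above, with assumed concrete sizes in place of the symbolic `subs_idxc` comparison; `needs_bound` is a hypothetical helper, not part of the library:

```python
# A dim needs a runtime bound exactly when the distributed work
# (work_bound) differs from the actual dim size, i.e. the last tile
# is partial.
def needs_bound(dim_size: int, tile_size: int, count: int) -> bool:
    work_bound = count * tile_size  # WorkgroupConstraint.work_bound
    return work_bound != dim_size

assert needs_bound(100, 32, 4)      # unaligned: partial last tile, bound M
assert not needs_bound(128, 32, 4)  # aligned: no bound emitted
```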
