
Commit 5b2ee61

GMNGeoffrey authored and nithinsubbiah committed
[TKW] Add missing validation and error messages (iree-org#432)
This is a collection of small changes adding additional validation or more verbose error messages to TKW. Includes:

- More verbose (or any) error messages for existing asserts and exceptions. Maybe this is too much, but the current messages frequently don't tell you what's going wrong. I think it's better to err on the side of too much information. If these get annoying for some reason we can trim them.
- Transforming some builtin `KeyError` into more specific error messages with context.
- Making it an error (rather than a warning print) if `get_custom` is passed a custom op instead of an `fx.Node`. I think this more likely than not points to a bug, so it's better for it to be an error. I can't remember the specific case in which I hit it.
- Adding earlier validation if there's a type mismatch between a reduce op's init and return types.
- Adding earlier validation if there's an issue when decomposing reduce ops and the local reduction doesn't match the accumulator reduction.
- Reporting the argument that has an issue if there's a failure in decomposing reduce ops.
- Printing which node had an issue if there's a failure during codegen.
- Validating that the reduction and the generated for loop have the same number of arguments. This otherwise results in a failure later on, but we can give more useful information here. I reported iree-org#384 for the bug that causes this to fire.
- Validating MMA shapes: `m` has to be in `lhs` and `n` in `rhs`. Locally, I actually have much more restrictive validation that `lhs` had to be `[..., m, k]` and `rhs` `[..., n, k]`. In theory it looks like Wave is supposed to figure things out if that isn't the setup, but I never had a case where it actually worked, so it seems like you need to walk some narrow path. This version is the less restrictive one though.
- Reporting more information if IREE invocation fails.
- Removing the assumption that a reduction has users in `get_users`.
- Removing some unused variables and arguments.
- Adding a `__str__` method to `IndexingContext` and `__repr__` methods to `ExpansionInfo` and `ReductionInfo`. Maybe these should be data classes?
- Adding some missing type annotations to some functions.

One note is that I'm not really sure what the convention is for exception types in the project, so a lot of these are just `RuntimeError`. That's not awesome, but I think it's still a lot more helpful than nothing. I tried to avoid `raise ... from` as in my experience these usually result in unhelpfully verbose stacks.
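The recurring shape of these changes is catching an uninformative builtin exception (or a bare assert) and re-raising with the surrounding context attached to the message. A minimal sketch of that pattern, with hypothetical names (`bindings`, `reference`) standing in for the real TKW structures:

    def resolve(bindings: dict, reference):
        try:
            return bindings[reference]
        except KeyError:
            # List every known key so the failure message is self-explanatory.
            known = "\n".join(f"{k}: {v}" for k, v in bindings.items())
            raise KeyError(f"{reference} not in bindings:\n{known}")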
1 parent 2d83a4d commit 5b2ee61

12 files changed: +200 -49 lines changed


iree/turbine/kernel/_support/indexing.py

Lines changed: 13 additions & 1 deletion
@@ -117,6 +117,18 @@ def __init__(self):
         self.frozen_subs: list[tuple[IndexSymbol, int]] = []
         self.unbacked_symbols: list[IndexSymbol] = []

+    def __str__(self):
+        return (
+            f"IndexingContext("
+            f"subs: {self.subs}\n"
+            f"special_subs: {self.special_subs}\n"
+            f"shaped_bindings: {self.shaped_bindings}\n"
+            f"dyn_dims: {self.dyn_dims}\n"
+            f"frozen_subs: {self.frozen_subs}\n"
+            f"unbacked_symbols: {self.unbacked_symbols}\n"
+            ")"
+        )
+
     def next_dyn_dim(self) -> IndexSymbol:
         s = index_symbol(f"D{len(self.dyn_dims)}")
         self.dyn_dims.append(s)
@@ -157,7 +169,7 @@ def _bind_symbol(self, symbol: IndexSymbol, value: int):
         self.subs[symbol] = value

     def finalize(self):
-        assert len(self.frozen_subs) == 0
+        assert len(self.frozen_subs) == 0, f"{self.frozen_subs=}"
         # Go over everything we know and bind all free symbols.
         for _sb in self.shaped_bindings.values():
             for i in range(_sb.shaped_type.rank):

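The new assert message relies on the `=` specifier in f-strings (Python 3.8+), which renders both the expression text and its value. A standalone illustration (not TKW code) of what the `finalize` assert would report:

    frozen_subs = [("M", 64)]
    assert len(frozen_subs) == 0, f"{frozen_subs=}"
    # AssertionError: frozen_subs=[('M', 64)]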
iree/turbine/kernel/compiler/kernel_codegen.py

Lines changed: 7 additions & 1 deletion
@@ -346,7 +346,13 @@ def __init__(self, sig: KernelSignature, entry_block: Block):
         }

     def resolve_by_reference(self, reference: Any) -> Value:
-        binding = self._bindings_by_reference[reference]
+        try:
+            binding = self._bindings_by_reference[reference]
+        except KeyError:
+            pretty = "\n".join(
+                f"{k}: {v}" for k, v in self._bindings_by_reference.items()
+            )
+            raise KeyError(f"{reference} not in signature:\n{pretty}")
         return self.resolve(binding)

     @abstractmethod

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 28 additions & 9 deletions
@@ -354,8 +354,7 @@ def new_function(*args: Any, **kwargs: dict[str, Any]):
 def get_custom(node: fx.Node) -> "CustomOp":
     """Get the corresponding CustomOp for a given fx.Node."""
     if isinstance(node, CustomOp):
-        print("Careful! You passed a custom op where an fx.Node was required.")
-        return node
+        raise ValueError(f"fx.Node required but got custom op {node}")
     if not isinstance(node, fx.Node):
         raise ValueError(f"Expected an fx.Node but got {type(node)}")

@@ -549,7 +548,7 @@ def erase(self):
         self.graph.erase_node(self.fx_node)

     @classmethod
-    def handle(cls, graph, *args, **kwargs) -> fx.Node:
+    def handle(cls, graph: RegionGraph, *args, **kwargs) -> fx.Node:
         node = cls(*args, **kwargs)
         node._add_proxy_to_graph(graph)
         node.fx_node.node.tkw_op = cls
@@ -1407,7 +1406,12 @@ class Reduction(NestedRegionOp):
     implicit_captures: Sequence[fx.Proxy]

     @classmethod
-    def handle(cls, graph, *args, **kwargs):
+    def handle(cls, graph: RegionGraph, *args, **kwargs):
+        if not isinstance(graph, RegionGraph):
+            raise TypeError(
+                f"handle expected {RegionGraph.__name__} but got {type(graph)}"
+            )
+
         def wrapper(f):
             with graph.subtracer() as subtracer:
                 subgraph_name, implicit_captures = subtracer.trace(f)
@@ -1689,8 +1693,14 @@ class GetResult(CustomOp):
     res_idx: int

     def infer_type(self):
-        src_type = get_custom(self.value).type
+        op = get_custom(self.value)
+        src_type = op.type
         if isinstance(src_type, list):
+            if self.res_idx >= len(src_type):
+                raise RuntimeError(
+                    f"GetResult of {self.res_idx} from result with {len(src_type)} results"
+                    f"\n{op=}\nsrc={self.value}\n{src_type=}"
+                )
             self.type = src_type[self.res_idx]
         else:
             self.type = src_type
@@ -1703,7 +1713,7 @@ def indexing_dims(self) -> list[IndexExpr]:
         )
         src_indexing = get_custom(self.value).indexing_dims
         if has_multiple_value(src_indexing):
-            assert self.res_idx <= len(src_indexing) - 1
+            assert self.res_idx < len(src_indexing), f"{self=}"
             src_indexing = src_indexing[self.res_idx]
         assert is_valid_indexing_dim(src_indexing)
         return src_indexing
@@ -1715,11 +1725,11 @@ def index(self) -> dict[IndexSymbol, IndexSequence]:
         if custom_index is None:
             return None
         if not isinstance(custom, Reduction):
-            return custom.index
+            return custom_index
         assert isinstance(custom_index, Sequence) and self.res_idx < len(
             custom.indexing_dims
-        )
-        return custom.index[self.res_idx]
+        ), f"Invalid {custom_index=} with {self.res_idx=} and {custom.indexing_dims=}\n{custom}"
+        return custom_index[self.res_idx]

     @index.setter
     def index(self, value: dict[IndexSymbol, IndexSequence]):
@@ -1882,6 +1892,15 @@ def infer_type(self):
         reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
         dst_type = Register[(*reduced_dims, src_type.dtype)]
         self.type = dst_type
+        if (
+            self.init is not None
+            and get_custom(self.init).type.symbolic_shape != self.type.symbolic_shape
+        ):
+            raise RuntimeError(
+                f"Init type for {self.tkw_op_name} {get_custom(self.init).type.symbolic_shape}"
+                f" must match reduce type {self.type.symbolic_shape}"
+                f"\n{self}"
+            )

     @property
     def num_reduction_dims(self) -> int:

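Since `get_custom` now raises instead of warning, call sites that relied on it tolerating a `CustomOp` need to pass the underlying `fx.Node`. A hypothetical before/after, assuming the op exposes its node as `.fx_node` as elsewhere in this file:

    custom = get_custom(some_custom_op)          # previously returned the op with a warning; now raises ValueError
    custom = get_custom(some_custom_op.fx_node)  # pass the fx.Node the function expects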
iree/turbine/kernel/wave/analysis/index_sequence_analysis.py

Lines changed: 35 additions & 7 deletions
@@ -110,7 +110,7 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
             continue
         if isinstance(custom, (Output, NestedRegionOp)):
             continue
-        assert custom.index, f"Index not set for node {custom.fx_node}"
+        assert custom.index, f"Index not set for node {custom.fx_node}: {custom}"
         if not custom.vector_shapes:
             # If vector_shapes is not set, see if it can be derived from the hardware constraints.
             hw_constraint = get_hardware_constraint(constraints)
@@ -121,7 +121,9 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
                 custom.vector_shapes = {}
                 for dim in update_vector_shapes:
                     custom.vector_shapes[dim] = hw_constraint.vector_shapes[dim]
-        assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}"
+        assert (
+            custom.vector_shapes
+        ), f"Vector shapes not set for node {custom.fx_node}: {custom}"


 def set_node_indices(
@@ -685,7 +687,13 @@ def apply_offset(node: fx.Node):
             return False
         for dim, scale in custom.expanded_dims.items():
             if dim in custom.index:
-                custom.index[dim].start += scale * custom.vector_shapes[dim]
+                try:
+                    custom.index[dim].start += scale * custom.vector_shapes[dim]
+                except KeyError as e:
+                    raise RuntimeError(
+                        f"op index or vector shapes missing expanded dim {dim}:\n"
+                        f"{custom.index}\n{custom.vector_shapes}\n{custom}"
+                    )
         return False

     trace.walk(apply_offset)
@@ -741,8 +749,25 @@ def get_index(custom: CustomOp):
     lhs = get_custom(custom.lhs)
     rhs = get_custom(custom.rhs)

-    lhs_dim, lhs_size = get_largest_index_and_size(get_index(lhs))
-    rhs_dim, rhs_size = get_largest_index_and_size(get_index(rhs))
+    lhs_index = get_index(lhs)
+    rhs_index = get_index(rhs)
+
+    lhs_dim, lhs_size = get_largest_index_and_size(lhs_index)
+    rhs_dim, rhs_size = get_largest_index_and_size(rhs_index)
+
+    extra_error_info = (
+        f"\n{binary_op=}"
+        f"\n{lhs=}"
+        f"\n{lhs_index=}"
+        f"\n{lhs_dim=}"
+        f"\n{lhs_size=}"
+        f"\n{lhs.type.symbolic_shape=}"
+        f"\n{rhs=}"
+        f"\n{rhs_index=}"
+        f"\n{rhs_dim=}"
+        f"\n{rhs_size=}"
+        f"\n{rhs.type.symbolic_shape=}"
+    )

     # If they are equal we are done.
     if lhs_dim == rhs_dim and lhs_size == rhs_size:
@@ -753,7 +778,8 @@ def get_index(custom: CustomOp):
     # Cannot handle discrepancies when both shapes are > 1.
     if lhs_size > 1 and rhs_size > 1:
         raise NotImplementedError(
-            "Currently only support resolving discrepancies when one of the shapes is 1."
+            f"Currently only support resolving discrepancies when one of the shapes is 1."
+            f"{extra_error_info}"
         )

     broadcast_rhs = lhs_size > rhs_size
@@ -774,7 +800,9 @@ def get_index(custom: CustomOp):

     if not is_only_missing_dim and not is_innermost_dim:
         raise NotImplementedError(
-            "Currently only support resolving discrepancies when the broadcasting dimension is the innermost dimension."
+            f"Currently only support resolving discrepancies when the broadcasting"
+            f" dimension is the innermost dimension. {extra_error_info}"
+            f"\n{broadcast_dim=}"
         )

     # Broadcast

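The discrepancy handling here only covers the case where exactly one operand has a per-thread size of 1. A rough standalone sketch of the rule, with made-up sizes rather than real index sequences:

    lhs_size, rhs_size = 4, 1  # hypothetical per-thread element counts
    if lhs_size > 1 and rhs_size > 1:
        raise NotImplementedError("both operands have size > 1; cannot resolve")
    broadcast_rhs = lhs_size > rhs_size  # True here: rhs is broadcast up to lhs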
iree/turbine/kernel/wave/codegen/emitter.py

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,7 @@
 from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
 from dataclasses import dataclass
 from collections import namedtuple
+import sys

 import torch.fx as fx

@@ -112,7 +113,11 @@ def _emit_function_call_node(self, node: fx.Node):
         except KeyError:
             raise CodegenError(f"No handler registered for op {target_op}")

-        handler(self, node)
+        try:
+            handler(self, node)
+        except:
+            print(f"Error handling {node}", file=sys.stderr)
+            raise

     def lookup_node_values(self, node: fx.Node) -> List[Value]:
         assert NDEBUG or isinstance(node, fx.Node)

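The emitter change is a generic "name the failing item, then let the original exception propagate" wrapper around handler dispatch. A self-contained sketch of the same idea with a hypothetical `handlers` table:

    import sys

    def dispatch(handlers: dict, node) -> None:
        try:
            handler = handlers[node.op]
        except KeyError:
            raise RuntimeError(f"No handler registered for op {node.op}")
        try:
            handler(node)
        except Exception:
            # Report which node failed on stderr, then re-raise unchanged.
            print(f"Error handling {node}", file=sys.stderr)
            raise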
iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 20 additions & 3 deletions
@@ -391,9 +391,15 @@ def handle_generic_binary(emitter: WaveEmitter, node: fx.Node):
     rhs = cast_py_value(emitter, rhs)

     if lhs.ir_value.type != rhs.ir_value.type:
+        op = get_custom(node)
         raise ValidationError(
-            "Expected lhs and rhs to have same type."
-            f" Got: {lhs.ir_value.type} vs {rhs.ir_value.type}"
+            f"Expected lhs and rhs to have same type for\n"
+            f"{op}\nGot\n"
+            f"lhs: {lhs.ir_value.type} vs rhs: {rhs.ir_value.type}\n"
+            f"{lhs=}\n"
+            f"{rhs=}\n"
+            f"lhs={get_custom(op.lhs)}\n"
+            f"rhs={get_custom(op.rhs)}"
         )

     lhs = lhs.ir_value
@@ -768,6 +774,11 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):
         # Add mapping for iter args.
         subgraph: fx.Graph = emitter.trace.get_subgraph(subgraph)
         iter_args: list[fx.Node] = get_custom(node).iter_args(subgraph)
+        assert len(iter_args) == len(forOp.inner_iter_args), (
+            f"Len of reduction and for op iter args must match,"
+            f" Reduction args: {iter_args};"
+            f" For Op args: {[a.type for a in forOp.inner_iter_args]}"
+        )
         for i, v in enumerate(forOp.inner_iter_args):
             emitter.bind_node_proxy(iter_args[i], IRProxyValue(v))
         captured_vars: list[fx.Node] = get_custom(node).captured_vars(subgraph)
@@ -785,6 +796,12 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):
         flat_ret_values = [
             cast_py_value(emitter, value).ir_value for value in flat_ret_values
         ]
+        assert len(flat_ret_values) == len(flat_init_args), (
+            f"Loop must have the same number of return values as init args, but got\n"
+            f"{len(flat_ret_values)} vs {len(flat_init_args)}\n"
+            f"{flat_ret_values=}\n"
+            f"{flat_init_args=}\n"
+        )
         scf_d.YieldOp(flat_ret_values)

     emitter.bind_node_proxies(node, [IRProxyValue(v) for v in forOp.results_])
@@ -907,7 +924,7 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node):
         raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.")
     assert (
         vector_type.rank == 0 or vector_type.rank == 1
-    ), f"expected vector_type.rank == 1 but got {vector_type}"
+    ), f"expected vector_type.rank == 1 but got {vector_type}, {node}"

     # Handles scalar broadcast case.
     if vector_type.rank == 0:

iree/turbine/kernel/wave/decompose_reduce_ops.py

Lines changed: 30 additions & 12 deletions
@@ -80,7 +80,7 @@ def determine_shuffle_config(
     return cluster_size, cluster_stride[0]


-def get_graph_node(custom: CustomOp, graph: fx.Graph):
+def get_graph_node(custom: CustomOp, graph: fx.Graph) -> fx.Node:
     custom.add_to_graph(graph)
     custom = custom.fx_node
     return custom
@@ -117,7 +117,7 @@ def emit_local_reduction(
     reduction_src: list[fx.Node],
     graph: fx.Graph,
     local_reduction_size,
-):
+) -> fx.Node:
     """
     Does reduction over all the element carried along by ReductionOp at local
     thread/SIMT level. This is done by reducing expanded sources combining them
@@ -135,7 +135,7 @@ def emit_scalarized_local_reduction(
     reduction_src: list[fx.Node],
     graph: fx.Graph,
     local_reduction_size,
-):
+) -> fx.Node:
     """
     Special case of local reduction wher we try to scalarize/get rid of most vector ops.
     this is useful for maximum, to expose more opportunities for v_max3_f32,
@@ -184,7 +184,6 @@ def emit_global_reduction(
 def decompose_reduce_ops(
     trace: CapturedTrace,
     constraints: list[Constraint],
-    index_map: dict[IndexSymbol, int],
 ):
     """
     The lowering for multi_reduction is done in two steps:
@@ -205,11 +204,6 @@ def decompose_reduce_ops(
     hardware_constraint = next(
         c for c in constraints if isinstance(c, HardwareConstraint)
     )
-    constraint_tile_size = {
-        c.dim: c.tile_size
-        for c in constraints
-        if isinstance(c, TilingConstraint) or isinstance(c, WorkgroupConstraint)
-    }
     induction_vars = [
         c.induction_var for c in constraints if isinstance(c, TilingConstraint)
     ]
@@ -242,9 +236,20 @@ def decompose_reduce_ops(
         get_thread_shape = lambda index: max(
             subs_idxc(x.size) for x in index.values()
         )
-        local_reduce_sizes = [
-            get_thread_shape(get_custom(arg).index) for arg in reduction_src
-        ]
+        local_reduce_sizes = []
+        for arg in reduction_src:
+            try:
+                op = get_custom(arg)
+
+                thread_shape = get_thread_shape(op.index)
+                local_reduce_sizes.append(thread_shape)
+            except Exception as e:
+                index_str = "\n".join(f"{k}: {v}" for k, v in op.index.items())
+                raise RuntimeError(
+                    f"Error in decompose_reduce_ops: {arg} with index\n"
+                    f"{index_str}\n{reduction_src=}\n{reduction_acc=}\n{reduction_dim=}"
+                ) from e
+
         if not all_equal(local_reduce_sizes):
             raise NotImplementedError(
                 "NYI: Expect all reduce_src to have same local reduce size."
@@ -258,6 +263,19 @@ def decompose_reduce_ops(
             binary_fn, reduction_src, custom.graph, local_reduce_sizes[0]
         )

+        if (
+            reduction_acc is not None
+            and get_custom(local_reduction).type.symbolic_shape
+            != get_custom(reduction_acc).type.symbolic_shape
+        ):
+            raise RuntimeError(
+                "Local reduction and accumulator reduction must have same shape."
+                f"\nlocal_reduction: {get_custom(local_reduction).type.symbolic_shape}"
+                f"\nreduction_acc: {get_custom(reduction_acc).type.symbolic_shape}"
+                f"\nlocal_reduction: {get_custom(local_reduction)}"
+                f"\nreduction_acc: {get_custom(reduction_acc)}"
+                f"\n{custom}"
+            )
         # Global Reduce
         cluster_size, cluster_stride = determine_shuffle_config(
             reduction_src[0].index,
