Commit b403f2e

[Ref Mode] Expand ref eager mode support to more hl.* APIs (e.g. load / store / scan / reduce) (#410)
1 parent 06f1b52 commit b403f2e

33 files changed: +1206 / -190 lines

helion/_testing.py

Lines changed: 37 additions & 0 deletions
@@ -15,6 +15,7 @@
 from typing import Generator
 import unittest
 
+import pytest
 import torch
 from triton.testing import do_bench
 
@@ -110,6 +111,8 @@ class RefEagerTestBase:
     # Class-level tracking for skipTest counting
     _skip_test_count = 0
     _original_skip_test_func = None
+    # Class-level tracking for pytest.raises patching
+    _original_pytest_raises = None
 
     def setUp(self) -> None:
         """Common setup for all ref eager tests."""
@@ -163,6 +166,18 @@ def counting_skip_test(*args: object, **kwargs: object) -> object:
         self._run_ref_tracker = track_run_ref_calls()
         self._run_ref_count = self._run_ref_tracker.__enter__()
 
+        # Patch pytest.raises to count calls
+        if RefEagerTestBase._original_pytest_raises is None:  # pyright: ignore[reportAttributeAccessIssue]
+            RefEagerTestBase._original_pytest_raises = pytest.raises
+
+        def counting_pytest_raises(*args: object, **kwargs: object) -> object:
+            """Wrapper for pytest.raises that counts calls but still runs the original logic."""
+            RefEagerTestBase._assert_raises_count += 1
+            assert RefEagerTestBase._original_pytest_raises is not None  # pyright: ignore[reportAttributeAccessIssue]
+            return RefEagerTestBase._original_pytest_raises(*args, **kwargs)  # pyright: ignore[reportAttributeAccessIssue]
+
+        pytest.raises = counting_pytest_raises  # type: ignore[assignment]
+
     def tearDown(self) -> None:
         """Common teardown with assertion counting check."""
         # If not in ref eager mode, skip the teardown logic
@@ -215,6 +230,10 @@ def tearDown(self) -> None:
         if RefEagerTestBase._original_skip_test_func is not None:
             self.skipTest = RefEagerTestBase._original_skip_test_func
 
+        # Restore the original pytest.raises function
+        if RefEagerTestBase._original_pytest_raises is not None:  # pyright: ignore[reportAttributeAccessIssue]
+            pytest.raises = RefEagerTestBase._original_pytest_raises  # pyright: ignore[reportAttributeAccessIssue]
+
         super().tearDown()  # type: ignore[misc]
 
     # NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode.
@@ -235,6 +254,10 @@ def assertNotIn(
         if not self._in_ref_eager_mode:
             super().assertNotIn(member, container, msg)  # type: ignore[misc]
 
+    def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            self.assertTrue(condition, msg)  # type: ignore[attr-defined]
+
     def assertEqualCode(self, first: str, second: str, msg: str | None = None) -> None:
         if not self._in_ref_eager_mode:
             super().assertEqual(first, second, msg)  # type: ignore[misc]
@@ -245,6 +268,20 @@ def assertNotEqualCode(
         if not self._in_ref_eager_mode:
             super().assertNotEqual(first, second, msg)  # type: ignore[misc]
 
+    def getUserDefinedTunable(
+        self, user_defined_tunables: dict[str, object], key: str
+    ) -> object | None:
+        """Look up a specific value via key from user defined tunables. Returns None in ref mode."""
+        if self._in_ref_eager_mode:
+            return None
+        return user_defined_tunables.get(key)
+
+    def assertIsInstance(
+        self, obj: object, cls: type | tuple[type, ...], msg: str | None = None
+    ) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIsInstance(obj, cls, msg)  # type: ignore[misc]
+
 
 def import_path(filename: Path) -> types.ModuleType:
     module_name = f"{__name__}.{filename.stem}"
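
The patch above wraps pytest.raises in a counting shim and restores the original in tearDown, mirroring the existing skipTest counting. A minimal standalone sketch of that wrap-count-restore pattern (the counter and function names here are illustrative, not from the commit):

import pytest

_raises_count = 0
_original_raises = pytest.raises  # keep a handle so we can restore it

def counting_raises(*args: object, **kwargs: object) -> object:
    """Count each pytest.raises call, then delegate to the original."""
    global _raises_count
    _raises_count += 1
    return _original_raises(*args, **kwargs)

pytest.raises = counting_raises  # type: ignore[assignment]
try:
    with pytest.raises(ZeroDivisionError):
        1 / 0
finally:
    pytest.raises = _original_raises  # always restore, as tearDown does
assert _raises_count == 1

The no-op assert* overrides follow the same theme: in ref eager mode, checks that inspect generated code are skipped, since no code is generated.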

helion/_utils.py

Lines changed: 57 additions & 0 deletions
@@ -1,7 +1,64 @@
 from __future__ import annotations
 
 import collections
+from typing import Sequence
 
 counters: collections.defaultdict[str, collections.Counter[str]] = (
     collections.defaultdict(collections.Counter)
 )
+
+
+def create_shape_matching_slices(
+    shape1: Sequence[int], shape2: Sequence[int]
+) -> tuple[slice, ...]:
+    """Create slices to match the smaller of two shapes.
+
+    This is used for masking tensors to compatible shapes by taking the
+    minimum size in each dimension.
+
+    Args:
+        shape1: First shape (can be torch.Size or any sequence of ints)
+        shape2: Second shape (can be torch.Size or any sequence of ints)
+
+    Returns:
+        Tuple of slices that can be used to index a tensor
+    """
+    return tuple(slice(0, min(d1, d2)) for d1, d2 in zip(shape1, shape2, strict=False))
+
+
+def convert_size_arg(size: object) -> object:
+    """Convert a size argument that may contain RefTile objects.
+
+    Handles:
+    - Single RefTile -> int
+    - List/tuple containing RefTiles -> list with converted sizes
+    - Other values -> unchanged
+    """
+    # Import here to avoid circular dependency
+    from helion.language.ref_tile import RefTile
+
+    if isinstance(size, (list, tuple)):
+        return [convert_size_arg(item) for item in size]
+    if isinstance(size, RefTile):
+        return size._slice.stop - size._slice.start
+    return size
+
+
+def convert_tile_indices_to_slices(index: object) -> object:
+    """Convert RefTile objects in index to their corresponding slice objects.
+
+    Args:
+        index: Index that may contain RefTile objects or tuples of indices
+
+    Returns:
+        Index with RefTile objects replaced by their slice objects
+    """
+    # Import here to avoid circular dependency
+    from helion.language.ref_tile import RefTile
+
+    def _extract_slice(obj: object) -> object:
+        return obj._slice if isinstance(obj, RefTile) else obj
+
+    if isinstance(index, tuple):
+        return tuple(_extract_slice(idx) for idx in index)
+    return _extract_slice(index)
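
For intuition, create_shape_matching_slices takes the per-dimension minimum of two shapes, producing a slice tuple that trims both tensors to a mutually valid region; convert_size_arg and convert_tile_indices_to_slices do the analogous normalization for RefTile sizes and indices. A small usage sketch of the slice helper (assuming it is importable from helion._utils, as added above):

import torch
from helion._utils import create_shape_matching_slices

a = torch.randn(4, 7)
b = torch.randn(6, 5)

# Per-dimension minimum: min(4, 6) = 4 rows, min(7, 5) = 5 cols
slices = create_shape_matching_slices(a.shape, b.shape)
assert slices == (slice(0, 4), slice(0, 5))
assert a[slices].shape == b[slices].shape == (4, 5)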

helion/language/device_print.py

Lines changed: 5 additions & 0 deletions
@@ -90,3 +90,8 @@ def _(state: CodegenState) -> None:
     )
     stmt = create(ast.Expr, value=call_expr)
     state.add_statement(stmt)
+
+
+@_decorators.ref(device_print)
+def _(prefix: str, *values: object) -> None:
+    print(prefix, *values)
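
With this ref implementation, hl.device_print degrades to a plain host-side print in ref eager mode: the prefix and values are forwarded to Python's print unchanged. A sketch of the observable behavior (assuming the usual import of helion.language as hl, and setting aside that in compiled mode the call must appear inside a kernel body):

import torch
import helion.language as hl

x = torch.arange(3)
# In ref eager mode this dispatches to the ref implementation above,
# printing something like: x: tensor([0, 1, 2])
hl.device_print("x:", x)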

helion/language/memory_ops.py

Lines changed: 167 additions & 1 deletion
@@ -84,9 +84,45 @@ def _(state: CodegenState) -> ast.AST:
     )
 
 
+@_decorators.ref(store)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | torch.SymInt | float,
+    extra_mask: torch.Tensor | None = None,
+) -> None:
+    # Convert index list to tuple for tensor indexing
+    index_tuple = tuple(index)
+
+    # Apply extra mask if provided
+    if extra_mask is not None:
+        # Only store where the mask is True
+        if isinstance(value, torch.Tensor):
+            tensor[index_tuple] = torch.where(extra_mask, value, tensor[index_tuple])  # pyright: ignore[reportArgumentType]
+        else:
+            # For scalar values, we need to create a tensor of the right shape
+            current = tensor[index_tuple]  # pyright: ignore[reportArgumentType]
+            # Cast value to a proper numeric type for full_like
+            if isinstance(value, torch.SymInt):
+                numeric_value = int(value)
+            else:
+                numeric_value = value
+            tensor[index_tuple] = torch.where(  # pyright: ignore[reportArgumentType]
+                extra_mask, torch.full_like(current, numeric_value), current
+            )
+    else:
+        # Handle SymInt case for assignment
+        if isinstance(value, torch.SymInt):
+            tensor[index_tuple] = int(value)  # pyright: ignore[reportArgumentType]
+        else:
+            tensor[index_tuple] = value  # pyright: ignore[reportArgumentType]
+
+
 @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def load(
-    tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Load a value from a tensor using a list of indices.
 
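
The masked-store path above boils down to a torch.where between the incoming values and the tensor's current contents. Its semantics, sketched with plain torch ops outside of Helion:

import torch

dst = torch.zeros(4)
value = torch.tensor([10.0, 20.0, 30.0, 40.0])
mask = torch.tensor([True, False, True, False])

# Equivalent of the ref store with extra_mask: write only where the mask
# is True and keep the previous contents elsewhere.
dst[:] = torch.where(mask, value, dst)
assert dst.tolist() == [10.0, 0.0, 30.0, 0.0]
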
@@ -129,6 +165,83 @@ def _(node: torch.fx.Node) -> int:
     return 0  # loads are always masked to 0
 
 
+@_decorators.ref(load)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from .ref_tile import RefTile
+
+    if extra_mask is None:
+        return tensor[tuple(index)]  # pyright: ignore[reportArgumentType]
+
+    # Create zero result matching mask shape
+    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Process indices: convert RefTiles and clamp tensor indices
+    orig_indices, safe_indices, is_tensor_mask = [], [], []
+    for i, idx in enumerate(index):
+        if isinstance(idx, RefTile):
+            idx = idx.index  # Convert RefTile to tensor
+
+        if isinstance(idx, torch.Tensor):
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            orig_indices.append(idx)
+            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
+            is_tensor_mask.append(True)
+        else:
+            orig_indices.append(idx)
+            safe_indices.append(idx)
+            is_tensor_mask.append(False)
+
+    # Apply broadcasting if we have multiple tensor indices
+    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
+
+    if len(tensor_positions) > 1:
+        # Add unsqueeze operations for broadcasting
+        broadcast_indices = []
+        for i, (idx, is_tensor) in enumerate(
+            zip(safe_indices, is_tensor_mask, strict=False)
+        ):
+            if is_tensor:
+                new_idx = idx
+                # Add dimension for each other tensor index
+                for j, other_pos in enumerate(tensor_positions):
+                    if other_pos != i:
+                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
+                broadcast_indices.append(new_idx)
+            else:
+                broadcast_indices.append(idx)
+        values = tensor[tuple(broadcast_indices)]
+    else:
+        values = tensor[tuple(safe_indices)]
+
+    # Build validity mask
+    valid_mask = extra_mask.clone()
+    for i, (orig_idx, is_tensor) in enumerate(
+        zip(orig_indices, is_tensor_mask, strict=False)
+    ):
+        if is_tensor:
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
+            # Broadcast to match mask shape by adding dimensions
+            # Count how many tensor indices come before and after this one
+            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
+            n_after = sum(
+                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
+            )
+
+            # Add dimensions: n_after dimensions at the end, n_before at the beginning
+            for _ in range(n_after):
+                in_bounds = in_bounds.unsqueeze(-1)
+            for _ in range(n_before):
+                in_bounds = in_bounds.unsqueeze(0)
+            valid_mask = valid_mask & in_bounds
+
+    return torch.where(valid_mask, values, result)
+
+
 @has_side_effect
 @_decorators.api(allow_host_tensor=True)
 def atomic_add(
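
The masked-load reference above uses a standard clamp-then-mask trick: tensor indices are clamped into bounds so the gather itself is always safe, and lanes whose original index was out of bounds (or whose extra_mask bit is False) are zeroed afterwards, matching the invariant that loads are always masked to 0. The one-dimensional version of the trick with plain torch ops:

import torch

src = torch.tensor([1.0, 2.0, 3.0])
idx = torch.tensor([0, 2, 5])            # 5 is out of bounds
mask = torch.tensor([True, True, True])  # caller-provided extra_mask

safe = torch.clamp(idx, 0, src.numel() - 1)      # gather from a safe position
valid = mask & (idx >= 0) & (idx < src.numel())  # then invalidate OOB lanes
out = torch.where(valid, src[safe], torch.zeros_like(src[safe]))
assert out.tolist() == [1.0, 3.0, 0.0]
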
@@ -210,6 +323,59 @@ def _(
     return None
 
 
+@_decorators.ref(atomic_add)
+def _(
+    target: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | float,
+    sem: str = "relaxed",
+) -> None:
+    """Reference implementation of atomic_add for interpret mode."""
+    from .. import exc
+    from .ref_tile import RefTile
+
+    # Validate sem parameter
+    if sem not in ["relaxed", "acquire", "release", "acq_rel"]:
+        raise exc.InternalError(
+            ValueError(
+                f"Invalid memory semantic '{sem}'. Valid options are: relaxed, acquire, release, acq_rel"
+            )
+        )
+
+    # Convert indices to proper format
+    processed_index = []
+    for idx in index:
+        if isinstance(idx, RefTile):
+            processed_index.append(idx._slice)
+        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
+            processed_index.append(int(idx.item()))
+        else:
+            processed_index.append(idx)
+
+    # Find tensor indices that need element-wise processing
+    tensor_indices = [
+        (i, idx)
+        for i, idx in enumerate(processed_index)
+        if isinstance(idx, torch.Tensor) and idx.numel() > 1
+    ]
+
+    if tensor_indices:
+        # Element-wise processing for tensor indices
+        i, tensor_idx = tensor_indices[0]  # Handle first tensor index
+        for j, elem in enumerate(tensor_idx):
+            new_index = processed_index.copy()
+            new_index[i] = int(elem.item())
+            val = (
+                value[j]
+                if isinstance(value, torch.Tensor) and value.numel() > 1
+                else value
+            )
+            target[tuple(new_index)] += val
+    else:
+        # Direct atomic add
+        target[tuple(processed_index)] += value
+
+
 @_decorators.codegen(atomic_add)
 def _(state: CodegenState) -> ast.AST:
     target = state.proxy_arg(0)
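
In eager reference mode there is no real concurrency, so atomic_add reduces to an in-place += once indices are normalized; sem is validated for API parity but otherwise has no effect. A behavioral sketch of the element-wise branch (illustrative values):

import torch

target = torch.zeros(4)
idx = torch.tensor([1, 3, 1])   # note the duplicate index 1
vals = torch.tensor([5.0, 7.0, 2.0])

# The element-wise branch above walks the index tensor one entry at a time:
for j, e in enumerate(idx):
    target[int(e.item())] += vals[j]
assert target.tolist() == [0.0, 7.0, 0.0, 7.0]

The per-element loop is presumably what makes duplicate indices accumulate correctly; a vectorized target[idx] += vals would drop repeated contributions under PyTorch's fancy-indexing semantics.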
