Skip to content

Commit 8b9815a

Browse files
committed
try fix Tile(block_id) error
1 parent cdfcbc8 commit 8b9815a

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

benchmarks/run.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,23 @@ def run_kernel_variants(
325325
operator_args: dict[str, Any] | None = None,
326326
) -> None:
327327
"""Run kernel variants in the same benchmark run."""
328+
329+
# Configure Helion to use fewer generations for faster autotuning during benchmarks
330+
import helion
331+
from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
332+
from helion.runtime.kernel import BoundKernel
333+
from typing import Sequence
334+
335+
def fast_autotuner_fn(
336+
bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
337+
) -> LocalAutotuneCache:
338+
# Use only 1 generation instead of default 20 for faster benchmarking
339+
return LocalAutotuneCache(
340+
DifferentialEvolutionSearch(bound_kernel, args, num_generations=1, **kwargs)
341+
)
342+
343+
# Set the custom autotuner function
344+
helion.set_default_settings(helion.Settings(autotuner_fn=fast_autotuner_fn))
328345

329346
# Configure Helion to use fewer generations for faster autotuning during benchmarks
330347
import helion

helion/_compiler/type_propagation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,12 +990,19 @@ def __init__(self, origin: Origin, block_id: int) -> None:
990990
self.block_id = block_id
991991

992992
def proxy(self) -> object:
993+
from ..language.tile_proxy import Tile as TileClass
994+
993995
with proxy_tensor.disable_proxy_modes_tracing():
994996
fake_mode = torch._C._unset_dispatch_mode( # pyright: ignore[reportAttributeAccessIssue]
995997
torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue]
996998
)
997999
try:
998-
return Tile(self.block_id)
1000+
# Create a Tile instance using torch.as_subclass to properly handle tensor subclassing
1001+
# This avoids the "already associated to a python object" error
1002+
base_tensor = torch.empty([], dtype=torch.int64, device='meta')
1003+
tile = base_tensor.as_subclass(TileClass)
1004+
tile.block_id = self.block_id
1005+
return tile
9991006
finally:
10001007
assert fake_mode is not None
10011008
torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue]

0 commit comments

Comments
 (0)