Skip to content

Commit aabe823

Browse files
authored
Add extensive setter/getter unit tests for indexed tensor; fix bugs discovered by new tests (#422)
1 parent 27f7f1c commit aabe823

File tree

4 files changed

+470
-6
lines changed

4 files changed

+470
-6
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,10 @@ def boundary_check(self, state: CodegenState) -> str:
518518
return repr(result)
519519
return "None"
520520

521-
def need_reshape(self) -> bool:
521+
def need_reshape(self, node: ast.AST) -> bool:
522+
if isinstance(node, ast.Constant):
523+
# Don't reshape scalar constants - they will be broadcast automatically
524+
return False
522525
if len(self.reshaped_size) != len(self.block_shape):
523526
return True
524527
env = CompileEnvironment.current()
@@ -528,13 +531,13 @@ def need_reshape(self) -> bool:
528531
return False
529532

530533
def reshape_load(self, state: CodegenState, node: ast.AST) -> ast.AST:
531-
if not self.need_reshape():
534+
if not self.need_reshape(node):
532535
return node
533536
shape = state.tile_strategy.shape_str(self.reshaped_size)
534537
return expr_from_string(f"tl.reshape(node, {shape})", node=node)
535538

536539
def reshape_store(self, state: CodegenState, node: ast.AST) -> ast.AST:
537-
if not self.need_reshape():
540+
if not self.need_reshape(node):
538541
return node
539542
shape = state.tile_strategy.shape_str(self.block_shape)
540543
return expr_from_string(f"tl.reshape(node, {shape})", node=node)

helion/_testing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def skipIfRefEager(reason: str) -> Callable[[Callable], Callable]:
4040
return unittest.skipIf(os.environ.get("HELION_INTERPRET") == "1", reason)
4141

4242

43+
def skipIfNormalMode(reason: str) -> Callable[[Callable], Callable]:
44+
"""Skip test if running in normal mode (i.e. if HELION_INTERPRET=1 is not set)."""
45+
return unittest.skipIf(os.environ.get("HELION_INTERPRET") != "1", reason)
46+
47+
4348
@contextlib.contextmanager
4449
def track_run_ref_calls() -> Generator[list[int], None, None]:
4550
"""Context manager that tracks BoundKernel.run_ref calls.

helion/autotuner/config_generation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ def _collect_spec(spec: ConfigSpecFragment) -> object:
5252
for i, spec in enumerate(self.flat_spec)
5353
if spec.category() == Category.NUM_WARPS
5454
)
55-
self.min_block_size: int = max(
56-
[spec.min_size for spec in config_spec.block_sizes]
55+
self.min_block_size: int = (
56+
max([spec.min_size for spec in config_spec.block_sizes])
57+
if config_spec.block_sizes
58+
else 1
5759
)
5860

5961
def unflatten(self, flat_values: FlatConfig) -> Config:
@@ -80,7 +82,9 @@ def get_next_value(spec: ConfigSpecFragment) -> object:
8082

8183
def block_numel(self, flat_config: FlatConfig) -> int:
8284
return functools.reduce(
83-
operator.mul, [cast("int", flat_config[i]) for i in self.block_size_indices]
85+
operator.mul,
86+
[cast("int", flat_config[i]) for i in self.block_size_indices],
87+
1,
8488
)
8589

8690
def shrink_config(

0 commit comments

Comments
 (0)