Skip to content

Commit 5d1a3ac

Browse files
authored
Remove ErrorReporting class and simplify warning handling (#204)
1 parent 78b9c7a commit 5d1a3ac

File tree

5 files changed

+12
-107
lines changed

5 files changed

+12
-107
lines changed

helion/_compiler/compile_environment.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import contextlib
55
import dataclasses
6+
import sys
67
import threading
78
import types
89
import typing
@@ -19,7 +20,6 @@
1920

2021
from .. import exc
2122
from ..language.constexpr import ConstExpr
22-
from .error_reporting import ErrorReporting
2323
from .loop_dependency_checker import LoopDependencyChecker
2424
from .variable_origin import BlockSizeOrigin
2525
from .variable_origin import Origin
@@ -55,7 +55,6 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
5555
super().__init__()
5656
self.device = device
5757
self.settings = settings
58-
self.errors = ErrorReporting(settings)
5958
self.shape_env = ShapeEnv(
6059
specialize_zero_one=True,
6160
duck_shape=False,
@@ -293,7 +292,6 @@ def __enter__(self) -> Self:
293292
assert getattr(tls, "env", None) is None, "CompileEnvironment already active"
294293
self.fake_mode.__enter__()
295294
tls.env = self
296-
self.errors = ErrorReporting(self.settings) # clear prior errors
297295
self.loop_dependency_checker = LoopDependencyChecker()
298296
return self
299297

@@ -305,7 +303,6 @@ def __exit__(
305303
) -> None:
306304
tls.env = None
307305
self.fake_mode.__exit__(exc_type, exc_value, traceback)
308-
self.errors.raise_if_errors()
309306

310307
@staticmethod
311308
def current() -> CompileEnvironment:
@@ -482,7 +479,17 @@ def from_config(self, config: Config, block_id: int) -> int | None:
482479

483480

484481
def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None:
485-
CompileEnvironment.current().errors.add(warning)
482+
"""Print a warning to stderr if it's not in the ignore list."""
483+
env = CompileEnvironment.current()
484+
if callable(warning):
485+
warning = warning()
486+
487+
if not isinstance(warning, exc.BaseWarning):
488+
raise TypeError(f"expected BaseWarning, got {type(warning)}")
489+
490+
# Check if this warning type should be ignored
491+
if not isinstance(warning, tuple(env.settings.ignore_warnings)):
492+
print(f"WARNING[{type(warning).__name__}]: {warning.args[0]}", file=sys.stderr)
486493

487494

488495
def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:

helion/_compiler/device_ir.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,6 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
883883
visitor = WalkHostAST(device_ir)
884884
for stmt in func.body:
885885
visitor.visit(stmt)
886-
CompileEnvironment.current().errors.raise_if_errors()
887886
for graph in device_ir.graphs:
888887
prepare_graph_lowerings(graph.graph)
889888
for graph in device_ir.graphs:

helion/_compiler/error_reporting.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

helion/_compiler/generate_ast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
399399
with codegen.device_function:
400400
for stmt in func.body:
401401
codegen.add_statement(codegen.visit(stmt))
402-
CompileEnvironment.current().errors.raise_if_errors()
403402
kernel_def = codegen.device_function.codegen_function_def()
404403
host_def = func.codegen_function_def(codegen.host_statements)
405404
precompile_def = codegen_precompile_def(

helion/_compiler/host_function.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def __init__(
103103
from .type_propagation import propagate_types
104104

105105
propagate_types(self, fake_args)
106-
env.errors.raise_if_errors()
107106
env.finalize_config_spec()
108107
self.device_ir = lower_to_device_ir(self)
109108

@@ -208,8 +207,6 @@ def debug_str(self) -> str:
208207
),
209208
self.device_ir.debug_str(),
210209
]
211-
if error_str := CompileEnvironment.current().errors.report(strip_paths=True):
212-
result.extend(error_str)
213210
return "\n\n".join(result)
214211

215212
def codegen_function_def(self, statements: list[ast.AST]) -> ast.FunctionDef:

0 commit comments

Comments
 (0)