Skip to content

Commit e8f8a35

Browse files
pytorchbot, StrongerXi, wdziurdz
authored
[dynamo] replace unimplemented with unimplemented_v2 in variables/functions.py (pytorch#153533)
* [dynamo] replace `unimplemented` with `unimplemented_v2` in `variables/functions.py` (pytorch#151277) This addresses part of pytorch#147913. Pull Request resolved: pytorch#151277 Approved by: https://github.com/Skylion007, https://github.com/williamwen42 (cherry picked from commit 9e24f9b) * Fix missing module import graph_break_hints (pytorch#153609) --------- Co-authored-by: Ryan Guo <[email protected]> Co-authored-by: Witold Dziurdz <[email protected]>
1 parent bdec157 commit e8f8a35

File tree

2 files changed

+60
-37
lines changed

2 files changed

+60
-37
lines changed

test/dynamo/test_decorators.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,7 @@ def trace_me(x, y):
644644
fn(torch.ones(10), torch.ones(1))
645645
self.assertFalse(True) # must raise error before this
646646
except torch._dynamo.exc.Unsupported as e:
647-
msg = """
648-
Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
649-
""" # NOQA: B950
647+
msg = "Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # NOQA: B950
650648
self.assertIn(msg, str(e))
651649

652650
def test_nonstrict_trace_custom_class_error(self):

torch/_dynamo/variables/functions.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import torch
3838

39-
from .. import polyfills, variables
39+
from .. import graph_break_hints, polyfills, variables
4040
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
4141
from ..exc import (
4242
get_dynamo_observed_exception,
@@ -47,7 +47,6 @@
4747
ObservedUserStopIteration,
4848
raise_observed_exception,
4949
SkipFrame,
50-
unimplemented,
5150
unimplemented_v2,
5251
Unsupported,
5352
)
@@ -368,16 +367,27 @@ def call_function(
368367
fn_var = bound.args[0]
369368
if not isinstance(fn_var, BaseUserFunctionVariable):
370369
typ = fn_var.python_type()
371-
unimplemented(
372-
f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
370+
msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
371+
unimplemented_v2(
372+
gb_type="TypeError from user code",
373+
context=f"call_function({self.value}, {args}, {kwargs})",
374+
explanation=msg,
375+
hints=[
376+
*graph_break_hints.USER_ERROR,
377+
],
373378
)
374379

375380
if not isinstance(fn_var, UserFunctionVariable):
376381
fn_name = fn_var.get_name()
377-
unimplemented(
378-
f"""
379-
Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
380-
""" # NOQA: B950
382+
msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950
383+
unimplemented_v2(
384+
gb_type="Limitation of `nonstrict_trace",
385+
context=f"{self}",
386+
explanation=msg,
387+
hints=[
388+
f"make sure definition of {fn_name} is outside ",
389+
"`torch.compile` region",
390+
],
381391
)
382392

383393
fn = fn_var.fn
@@ -1189,16 +1199,10 @@ def call_function(
11891199
"Remove the `torch._dynamo.graph_break()` call.",
11901200
],
11911201
)
1192-
elif isinstance(self.value, types.WrapperDescriptorType):
1193-
msg = (
1194-
f"Graph break due to unsupported wrapper descriptor {self.value}. "
1195-
f"Please file an issue on GitHub "
1196-
f"so the PyTorch team can add support for it. "
1197-
)
1198-
torch._dynamo.utils.warn_once(msg)
1199-
unimplemented(msg)
12001202
else:
12011203
qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
1204+
module_or = getattr(self.value, "__module__", None)
1205+
module_name = "<unknown module>" if module_or is None else str(module_or)
12021206
try:
12031207
path = inspect.getfile(self.value)
12041208
explanation = (
@@ -1221,22 +1225,19 @@ def call_function(
12211225
]
12221226
except TypeError:
12231227
known_python_builtin_modules = {"_abc", "_warnings"}
1224-
if self.value.__module__ in known_python_builtin_modules:
1228+
if module_or in known_python_builtin_modules:
12251229
explanation = (
12261230
f"Dynamo does not know how to trace the Python builtin "
1227-
f"`{self.value.__module__}.{qualname}`."
1231+
f"`{module_name}.{qualname}`."
12281232
)
12291233
hints = [
12301234
"If you are attempting to call a logging function (e.g. `_warnings.warn`), "
12311235
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
12321236
"Please file an issue on GitHub "
12331237
"so the PyTorch team can add support for it. ",
12341238
]
1235-
elif (
1236-
self.value.__module__ is not None
1237-
and self.value.__module__.startswith("optree")
1238-
):
1239-
explanation = f"Dynamo cannot trace optree C/C++ function {self.value.__module__}.{qualname}."
1239+
elif module_or is not None and module_or.startswith("optree"):
1240+
explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
12401241
hints = [
12411242
" Consider using torch.utils._pytree - "
12421243
"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
@@ -1245,7 +1246,7 @@ def call_function(
12451246
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
12461247
else:
12471248
explanation = (
1248-
f"Dynamo does not know how to trace the builtin `{self.value.__module__}.{qualname}.` "
1249+
f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
12491250
f"This function is either a Python builtin (e.g. _warnings.warn) "
12501251
f"or a third-party C/C++ Python extension (perhaps created with pybind)."
12511252
)
@@ -1270,7 +1271,7 @@ def call_function(
12701271
reason = self.reason if self.reason else "<missing reason>"
12711272
unimplemented_v2(
12721273
gb_type="Attempted to call function marked as skipped",
1273-
context=f"module: {self.value.__module__}, qualname: {qualname}, skip reason: {reason}",
1274+
context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
12741275
explanation=explanation,
12751276
hints=hints,
12761277
)
@@ -1395,8 +1396,13 @@ def call_function(
13951396
args = ()
13961397

13971398
if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
1398-
unimplemented(
1399-
f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
1399+
unimplemented_v2(
1400+
gb_type="async_op=True for distributed collectives",
1401+
context=f"{self.fn}, {args=}, {kwargs=}",
1402+
explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}",
1403+
hints=[
1404+
*graph_break_hints.SUPPORTABLE,
1405+
],
14001406
)
14011407

14021408
if self.fn in (
@@ -1430,7 +1436,14 @@ def call_function(
14301436
def wraps(fn):
14311437
if isinstance(fn, variables.NestedUserFunctionVariable):
14321438
return fn.clone(wrapped_fn=args[0])
1433-
unimplemented(f"functools.wraps({fn})")
1439+
unimplemented_v2(
1440+
gb_type="functools.wraps",
1441+
context=f"{fn}",
1442+
explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
1443+
hints=[
1444+
*graph_break_hints.SUPPORTABLE,
1445+
],
1446+
)
14341447

14351448
return variables.LambdaVariable(wraps)
14361449

@@ -1456,7 +1469,14 @@ def call_function(
14561469
return variables.UserDefinedClassVariable(
14571470
value, mutation_type=ValueMutationNew()
14581471
)
1459-
unimplemented("namedtuple with non constant args")
1472+
unimplemented_v2(
1473+
gb_type="namedtuple construction",
1474+
context=f"{args=}, {kwargs=}",
1475+
explanation="`torch.compile` only support certain input types for namedtuple",
1476+
hints=[
1477+
*graph_break_hints.SUPPORTABLE,
1478+
],
1479+
)
14601480

14611481

14621482
class FunctoolsPartialVariable(VariableTracker):
@@ -1701,10 +1721,8 @@ def exception(self, tx):
17011721
def call_function(self, tx, args, kwargs):
17021722
if self.value is sys.exc_info:
17031723
return self.exc_info(tx)
1704-
elif self.value is sys.exception:
1705-
return self.exception(tx)
1706-
else:
1707-
unimplemented(f"sys.{self.value.__name__}")
1724+
assert self.value is sys.exception
1725+
return self.exception(tx)
17081726

17091727

17101728
from torch._higher_order_ops.triton_kernel_wrap import (
@@ -1731,7 +1749,14 @@ def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
17311749
if isinstance(grid, BaseListVariable):
17321750
return grid.as_proxy()
17331751
else:
1734-
unimplemented(f"grid for the triton kernel is {type(grid)}")
1752+
unimplemented_v2(
1753+
gb_type="unsupported grid type for triton hop check_grid",
1754+
context=f"grid type = {type(grid)}",
1755+
explanation="`torch.compile` only supports list-like grid for check_grid",
1756+
hints=[
1757+
*graph_break_hints.SUPPORTABLE,
1758+
],
1759+
)
17351760

17361761
def call_grid(self, grid, meta, tx):
17371762
meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()}

0 commit comments

Comments
 (0)