3636
3737import torch
3838
39- from .. import polyfills , variables
39+ from .. import graph_break_hints , polyfills , variables
4040from ..bytecode_transformation import create_call_function , create_rot_n , is_generator
4141from ..exc import (
4242 get_dynamo_observed_exception ,
4747 ObservedUserStopIteration ,
4848 raise_observed_exception ,
4949 SkipFrame ,
50- unimplemented ,
5150 unimplemented_v2 ,
5251 Unsupported ,
5352)
@@ -368,16 +367,27 @@ def call_function(
368367 fn_var = bound .args [0 ]
369368 if not isinstance (fn_var , BaseUserFunctionVariable ):
370369 typ = fn_var .python_type ()
371- unimplemented (
372- f"`nonstrict_trace` expects a callable, but got value of type <{ typ .__name__ } >"
370+ msg = f"`nonstrict_trace` expects a callable, but got value of type <{ typ .__name__ } >"
371+ unimplemented_v2 (
372+ gb_type = "TypeError from user code" ,
373+ context = f"call_function({ self .value } , { args } , { kwargs } )" ,
374+ explanation = msg ,
375+ hints = [
376+ * graph_break_hints .USER_ERROR ,
377+ ],
373378 )
374379
375380 if not isinstance (fn_var , UserFunctionVariable ):
376381 fn_name = fn_var .get_name ()
377- unimplemented (
378- f"""
379- Applying `nonstrict_trace` to function <{ fn_name } >; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
380- """ # NOQA: B950
382+ msg = f"Applying `nonstrict_trace` to function <{ fn_name } >; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950
383+ unimplemented_v2 (
384+ gb_type = "Limitation of `nonstrict_trace" ,
385+ context = f"{ self } " ,
386+ explanation = msg ,
387+ hints = [
388+ f"make sure definition of { fn_name } is outside " ,
389+ "`torch.compile` region" ,
390+ ],
381391 )
382392
383393 fn = fn_var .fn
@@ -1189,16 +1199,10 @@ def call_function(
11891199 "Remove the `torch._dynamo.graph_break()` call." ,
11901200 ],
11911201 )
1192- elif isinstance (self .value , types .WrapperDescriptorType ):
1193- msg = (
1194- f"Graph break due to unsupported wrapper descriptor { self .value } . "
1195- f"Please file an issue on GitHub "
1196- f"so the PyTorch team can add support for it. "
1197- )
1198- torch ._dynamo .utils .warn_once (msg )
1199- unimplemented (msg )
12001202 else :
12011203 qualname = getattr (self .value , "__qualname__" , "<unknown qualname>" )
1204+ module_or = getattr (self .value , "__module__" , None )
1205+ module_name = "<unknown module>" if module_or is None else str (module_or )
12021206 try :
12031207 path = inspect .getfile (self .value )
12041208 explanation = (
@@ -1221,22 +1225,19 @@ def call_function(
12211225 ]
12221226 except TypeError :
12231227 known_python_builtin_modules = {"_abc" , "_warnings" }
1224- if self . value . __module__ in known_python_builtin_modules :
1228+ if module_or in known_python_builtin_modules :
12251229 explanation = (
12261230 f"Dynamo does not know how to trace the Python builtin "
1227- f"`{ self . value . __module__ } .{ qualname } `."
1231+ f"`{ module_name } .{ qualname } `."
12281232 )
12291233 hints = [
12301234 "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
12311235 "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`." ,
12321236 "Please file an issue on GitHub "
12331237 "so the PyTorch team can add support for it. " ,
12341238 ]
1235- elif (
1236- self .value .__module__ is not None
1237- and self .value .__module__ .startswith ("optree" )
1238- ):
1239- explanation = f"Dynamo cannot trace optree C/C++ function { self .value .__module__ } .{ qualname } ."
1239+ elif module_or is not None and module_or .startswith ("optree" ):
1240+ explanation = f"Dynamo cannot trace optree C/C++ function { module_name } .{ qualname } ."
12401241 hints = [
12411242 " Consider using torch.utils._pytree - "
12421243 "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
@@ -1245,7 +1246,7 @@ def call_function(
12451246 torch ._dynamo .utils .warn_once (explanation + "\n " + "\n " .join (hints ))
12461247 else :
12471248 explanation = (
1248- f"Dynamo does not know how to trace the builtin `{ self . value . __module__ } .{ qualname } .` "
1249+ f"Dynamo does not know how to trace the builtin `{ module_name } .{ qualname } .` "
12491250 f"This function is either a Python builtin (e.g. _warnings.warn) "
12501251 f"or a third-party C/C++ Python extension (perhaps created with pybind)."
12511252 )
@@ -1270,7 +1271,7 @@ def call_function(
12701271 reason = self .reason if self .reason else "<missing reason>"
12711272 unimplemented_v2 (
12721273 gb_type = "Attempted to call function marked as skipped" ,
1273- context = f"module: { self . value . __module__ } , qualname: { qualname } , skip reason: { reason } " ,
1274+ context = f"module: { module_name } , qualname: { qualname } , skip reason: { reason } " ,
12741275 explanation = explanation ,
12751276 hints = hints ,
12761277 )
@@ -1395,8 +1396,13 @@ def call_function(
13951396 args = ()
13961397
13971398 if "async_op" in kwargs and kwargs ["async_op" ].as_python_constant ():
1398- unimplemented (
1399- f"CollectiveFunctionRewriteVariable can't support async_op=True for { self .fn } "
1399+ unimplemented_v2 (
1400+ gb_type = "async_op=True for distributed collectives" ,
1401+ context = f"{ self .fn } , { args = } , { kwargs = } " ,
1402+ explanation = f"`torch.compile` doesn't support `async_op=True for { self .fn } " ,
1403+ hints = [
1404+ * graph_break_hints .SUPPORTABLE ,
1405+ ],
14001406 )
14011407
14021408 if self .fn in (
@@ -1430,7 +1436,14 @@ def call_function(
14301436 def wraps (fn ):
14311437 if isinstance (fn , variables .NestedUserFunctionVariable ):
14321438 return fn .clone (wrapped_fn = args [0 ])
1433- unimplemented (f"functools.wraps({ fn } )" )
1439+ unimplemented_v2 (
1440+ gb_type = "functools.wraps" ,
1441+ context = f"{ fn } " ,
1442+ explanation = "`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region" ,
1443+ hints = [
1444+ * graph_break_hints .SUPPORTABLE ,
1445+ ],
1446+ )
14341447
14351448 return variables .LambdaVariable (wraps )
14361449
@@ -1456,7 +1469,14 @@ def call_function(
14561469 return variables .UserDefinedClassVariable (
14571470 value , mutation_type = ValueMutationNew ()
14581471 )
1459- unimplemented ("namedtuple with non constant args" )
1472+ unimplemented_v2 (
1473+ gb_type = "namedtuple construction" ,
1474+ context = f"{ args = } , { kwargs = } " ,
1475+ explanation = "`torch.compile` only support certain input types for namedtuple" ,
1476+ hints = [
1477+ * graph_break_hints .SUPPORTABLE ,
1478+ ],
1479+ )
14601480
14611481
14621482class FunctoolsPartialVariable (VariableTracker ):
@@ -1701,10 +1721,8 @@ def exception(self, tx):
17011721 def call_function (self , tx , args , kwargs ):
17021722 if self .value is sys .exc_info :
17031723 return self .exc_info (tx )
1704- elif self .value is sys .exception :
1705- return self .exception (tx )
1706- else :
1707- unimplemented (f"sys.{ self .value .__name__ } " )
1724+ assert self .value is sys .exception
1725+ return self .exception (tx )
17081726
17091727
17101728from torch ._higher_order_ops .triton_kernel_wrap import (
@@ -1731,7 +1749,14 @@ def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
17311749 if isinstance (grid , BaseListVariable ):
17321750 return grid .as_proxy ()
17331751 else :
1734- unimplemented (f"grid for the triton kernel is { type (grid )} " )
1752+ unimplemented_v2 (
1753+ gb_type = "unsupported grid type for triton hop check_grid" ,
1754+ context = f"grid type = { type (grid )} " ,
1755+ explanation = "`torch.compile` only supports list-like grid for check_grid" ,
1756+ hints = [
1757+ * graph_break_hints .SUPPORTABLE ,
1758+ ],
1759+ )
17351760
17361761 def call_grid (self , grid , meta , tx ):
17371762 meta = {variables .ConstantVariable .create (k ): v for k , v in meta .items ()}
0 commit comments