@@ -243,7 +243,9 @@ def make_input_proxy(arg_node):
     return proxy_args, proxy_kwargs


-def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def try_execute_thunder_symbol(
+    thunder_symbol: Symbol, node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
     """
     Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.

@@ -287,6 +289,7 @@ def get_requires_grad(arg_node):

     args, _ = tree_flatten((node.args, node.kwargs))
     requires_grad = any(map(get_requires_grad, args))
+    disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)

     @compile_data_and_stats(cd, cs)
     @thunder._with_cache_info_ctx
@@ -309,7 +312,12 @@ def _run_with_cache_info():
                 exception=str(e),
             )

-        function_to_run = value_and_grad(thunder_symbol) if requires_grad else thunder_symbol
+        function_to_run = thunder_symbol
+        function_to_run = (
+            value_and_grad(thunder_symbol)
+            if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
+            else thunder_symbol
+        )
         # We need to be under trace context to generate proxies.
         with thunder.core.trace.tracectx(TraceCtx()):
             try:
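The hunk above gates the `value_and_grad` wrapper on the new `disable_torch_autograd` option. A minimal usage sketch, assuming `ThunderCompiler` forwards its keyword arguments to the splitter as `thunder_options` (the model and shapes are illustrative):

import torch
from thunder.dynamo import ThunderCompiler

# Assumed forwarding: with disable_torch_autograd=True, the splitter tests
# symbols without the value_and_grad wrapper even when inputs require grad.
backend = ThunderCompiler(disable_torch_autograd=True)
model = torch.nn.Linear(4, 4)
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(2, 4, requires_grad=True))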
@@ -351,7 +359,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
     return nodes_in_unsupported_ctx_regions


-def is_graphmodule_supported_by_thunder(gm):
+def is_graphmodule_supported_by_thunder(gm, thunder_options: dict[str, Any]):
     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
     for node in gm.graph.nodes:
         if node.op in (
@@ -367,13 +375,15 @@ def is_graphmodule_supported_by_thunder(gm):
             )
             return False, split_reason

-        is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+        is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
         if not is_thunder_supported:
             return False, split_reason
     return True, None


-def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def is_node_supported_by_thunder(
+    node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
     """
     Determine whether thunder can execute the operation described by this node.
     """
@@ -425,7 +435,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         for arg_node in node.args:
             if arg_node.op == "get_attr":
                 called_module = getattr(m, arg_node.target)
-                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
+                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module, thunder_options)
                 if not is_module_supported:
                     return is_module_supported, split_reason
         return True, None
@@ -438,7 +448,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
     # We try to proxify the arguments and call these operations on them to see if they are supported.
     if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
         thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
-        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node, thunder_options)
         return did_run, opt_split_reason

     # There are few operations which are registered only as method in `torchctx` and hence they don't exist
@@ -457,7 +467,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
             )
         # NOTE: `get_method` may throw if relevant method is not found, so we have guarded it with `has_method`.
         method = torchctx.get_method(node.target, args, kwargs)
-        did_run, opt_split_reason = try_execute_thunder_symbol(method, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(method, node, thunder_options)
         return did_run, opt_split_reason

     # checks einops operators
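Taken together, the plumbing above threads `thunder_options` from the graph-module check down to symbol execution. The new gating reduces to the predicate below, an illustrative restatement rather than code from the patch:

from typing import Any

def should_wrap_with_value_and_grad(requires_grad: bool, thunder_options: dict[str, Any]) -> bool:
    # Wrap with value_and_grad only when gradients are required and the
    # "disable_torch_autograd" option is unset or falsy.
    disable_torch_autograd = thunder_options.get("disable_torch_autograd", None)
    return requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)

assert should_wrap_with_value_and_grad(True, {})
assert not should_wrap_with_value_and_grad(True, {"disable_torch_autograd": True})
assert not should_wrap_with_value_and_grad(False, {})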