diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..a30e9ecc2e 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -650,6 +650,9 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
                 thunder_node = gm.graph.call_function(
                     _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
                 )
+                # Copy metadata from the original node to preserve tensor properties like
+                # requires_grad, dtype, shape, etc. which are crucial for gradient computation
+                thunder_node.meta = n.meta.copy()
             n.replace_all_uses_with(thunder_node)
             gm.graph.erase_node(n)
         else: