
Commit eea8ff2

amaldevh authored and pytorchmergebot committed
Fix torch.full with dynamic tensor fill_value in torch.compile (pytorch#166554)
Fixes pytorch#166253 ## Summary When `torch.full` is called with a 0-D tensor as `fill_value` inside a `torch.compile`'d function, the value was being incorrectly cached, causing subsequent calls with different values to return the first value. ## Root Cause The Dynamo handler for `torch.full` was calling `aten._local_scalar_dense` to convert tensor fill_values to Python scalars at compile time, which baked the value into the compiled graph as a constant. ## Solution Modified the Dynamo handler to decompose `torch.full(size, tensor_fill_value)` into `empty(size).fill_(tensor_fill_value)` when `fill_value` is a `TensorVariable`, keeping the fill value dynamic in the compiled graph. ## Testing Added test case that verifies torch.full works correctly with dynamic tensor fill_values across multiple calls and dtypes. Pull Request resolved: pytorch#166554 Approved by: https://github.com/Lucaskabela
1 parent 11f73d7 commit eea8ff2

File tree

2 files changed: +63, −5 lines changed


test/dynamo/test_functions.py

Lines changed: 57 additions & 0 deletions
@@ -5241,6 +5241,63 @@ def forward(self, x):
         x = torch.randn(1)
         self.assertEqual(opt_mod(x), x + 1)
 
+    def test_full_with_tensor_fill_value(self):
+        """Test that torch.full works correctly with dynamic tensor fill_value"""
+
+        # Test with tensor fill_value (the bug case)
+        def func_tensor(x):
+            return torch.full((2,), x, dtype=torch.float64)
+
+        func_compiled = torch.compile(func_tensor)
+
+        # Test with different values
+        x1 = torch.tensor(5.0, dtype=torch.float64)
+        x2 = torch.tensor(10.0, dtype=torch.float64)
+
+        result1 = func_compiled(x1)
+        expected1 = torch.full((2,), x1, dtype=torch.float64)
+        self.assertEqual(result1, expected1)
+
+        # This is where the bug occurred - second call reused first value
+        result2 = func_compiled(x2)
+        expected2 = torch.full((2,), x2, dtype=torch.float64)
+        self.assertEqual(result2, expected2)
+
+        # Test with different dtypes
+        for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
+
+            def func_typed(x):
+                return torch.full((3,), x, dtype=dtype)
+
+            func_typed_compiled = torch.compile(func_typed)
+            x_typed = torch.tensor(7, dtype=dtype)
+            result = func_typed_compiled(x_typed)
+            expected = torch.full((3,), x_typed, dtype=dtype)
+            self.assertEqual(result, expected)
+
+        # Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
+        def func_scalar(size):
+            return torch.full((size,), 42.0, dtype=torch.float32)
+
+        func_scalar_compiled = torch.compile(func_scalar)
+
+        result_scalar = func_scalar_compiled(5)
+        expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
+        self.assertEqual(result_scalar, expected_scalar)
+
+        # Test with different scalar values
+        def func_scalar_param():
+            # Test multiple calls with different hardcoded scalar values
+            a = torch.full((2,), 3.14, dtype=torch.float32)
+            b = torch.full((2,), 2.71, dtype=torch.float32)
+            return a, b
+
+        func_scalar_param_compiled = torch.compile(func_scalar_param)
+        result_a, result_b = func_scalar_param_compiled()
+
+        self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
+        self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
+
 
 instantiate_parametrized_tests(FunctionTests)
 instantiate_parametrized_tests(DefaultsTests)

torch/_dynamo/variables/torch.py

Lines changed: 6 additions & 5 deletions
@@ -834,12 +834,13 @@ def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs):
         @register(torch.full)
         def handle_full(self, tx, size, fill_value, **kwargs):
             if isinstance(fill_value, TensorVariable):
-                result = TorchInGraphFunctionVariable(
-                    torch.ops.aten._local_scalar_dense
-                ).call_function(tx, [fill_value], {})
-                return TorchInGraphFunctionVariable(torch.full).call_function(
-                    tx, [size, result], kwargs
+                # Decompose: create empty tensor and fill it
+                # This avoids the scalar extraction at compile time
+                empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
+                    tx, [size], kwargs
                 )
+                # Call fill_ method on the empty tensor
+                return empty_result.call_method(tx, "fill_", [fill_value], {})
 
         @register(torch._foreach_lerp_)
         def handle_inplace_foreach_lerp_scalar(
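For intuition, a rough eager-mode analogue of what the handler now traces (illustrative only; `full_decomposed` is a hypothetical helper, not a PyTorch API):

```python
import torch

def full_decomposed(size, fill_value, **kwargs):
    # Allocate an uninitialized tensor, then fill it in place, so the
    # fill value stays an input to the traced graph rather than a scalar
    # extracted at compile time via aten._local_scalar_dense.
    return torch.empty(size, **kwargs).fill_(fill_value)

x = torch.tensor(3.0)
print(full_decomposed((2,), x, dtype=torch.float64))  # tensor([3., 3.], dtype=torch.float64)
```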
