Skip to content

Commit efbf07e

Browse files
Revert "[dynamo] Fix issue with tensors passed as view() shapes (pytorch#156928)"
This reverts commit 75f3e5a. Reverted pytorch#156928 on behalf of https://github.com/jeanschmidt due to Breaks a internal test, more details can be found on D77449971 ([comment](pytorch#156928 (comment)))
1 parent 5e18bc3 commit efbf07e

File tree

2 files changed

+0
-111
lines changed

2 files changed

+0
-111
lines changed

test/dynamo/test_view.py

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -33,86 +33,6 @@ def f(t, _n):
3333
t = torch.tensor([2, 4], dtype=torch.int32)
3434
f(t, 8)
3535

36-
def test_view_with_tensor_shape_params(self):
37-
# Test for issue #156720: aten.view.default with tensor shape parameters
38-
class TestModel(torch.nn.Module):
39-
def forward(self, x, shape_params):
40-
return torch.ops.aten.view.default(x, shape_params)
41-
42-
x = torch.randn(24)
43-
shape_params = [
44-
torch.tensor(2, dtype=torch.int32),
45-
torch.tensor(3, dtype=torch.int32),
46-
torch.tensor(4, dtype=torch.int32),
47-
]
48-
49-
model = TestModel()
50-
expected = model(x, shape_params)
51-
52-
compiled_model = torch.compile(model, backend="eager")
53-
result = compiled_model(x, shape_params)
54-
55-
torch.testing.assert_close(result, expected)
56-
57-
def test_tensor_view_with_tensor_shape_params(self):
58-
# Test tensor.view() method with tensor shape parameters (list version)
59-
class TestModel(torch.nn.Module):
60-
def forward(self, x, shape_params):
61-
return x.view(shape_params)
62-
63-
x = torch.randn(24)
64-
shape_params = (
65-
torch.tensor(2, dtype=torch.int32),
66-
torch.tensor(3, dtype=torch.int32),
67-
torch.tensor(4, dtype=torch.int32),
68-
)
69-
70-
model = TestModel()
71-
expected = model(x, shape_params)
72-
73-
compiled_model = torch.compile(model, backend="eager")
74-
result = compiled_model(x, shape_params)
75-
76-
torch.testing.assert_close(result, expected)
77-
78-
def test_tensor_view_with_tensor_args(self):
79-
# Test tensor.view() method with individual tensor arguments
80-
class TestModel(torch.nn.Module):
81-
def forward(self, x, dim1, dim2, dim3):
82-
return x.view(dim1, dim2, dim3)
83-
84-
x = torch.randn(24)
85-
dim1 = torch.tensor(2, dtype=torch.int32)
86-
dim2 = torch.tensor(3, dtype=torch.int32)
87-
dim3 = torch.tensor(4, dtype=torch.int32)
88-
89-
model = TestModel()
90-
expected = model(x, dim1, dim2, dim3)
91-
92-
compiled_model = torch.compile(model, backend="eager")
93-
result = compiled_model(x, dim1, dim2, dim3)
94-
95-
torch.testing.assert_close(result, expected)
96-
97-
def test_torch_reshape_with_tensor_shape_params(self):
98-
# Test torch.reshape() function with tensor shape parameters
99-
def test_fn(x, shape_params):
100-
return torch.reshape(x, shape_params)
101-
102-
x = torch.randn(24)
103-
shape_params = [
104-
torch.tensor(2, dtype=torch.int32),
105-
torch.tensor(3, dtype=torch.int32),
106-
torch.tensor(4, dtype=torch.int32),
107-
]
108-
109-
expected = test_fn(x, shape_params)
110-
111-
compiled_fn = torch.compile(test_fn, backend="eager")
112-
result = compiled_fn(x, shape_params)
113-
114-
torch.testing.assert_close(result, expected)
115-
11636

11737
if __name__ == "__main__":
11838
from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/torch.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -746,37 +746,6 @@ def handle_full(self, tx, size, fill_value, **kwargs):
746746
tx, [size, result], kwargs
747747
)
748748

749-
@register(torch.ops.aten.view.default)
750-
def handle_view_default(self, tx, tensor, shape):
751-
from ..utils import proxy_args_kwargs
752-
from .builder import wrap_fx_proxy
753-
from .lists import ListVariable
754-
755-
def convert_tensor_to_scalar(item):
756-
if isinstance(item, TensorVariable):
757-
return TorchInGraphFunctionVariable(
758-
torch.ops.aten._local_scalar_dense
759-
).call_function(tx, [item], {})
760-
return item
761-
762-
if isinstance(shape, ListVariable):
763-
converted_items = [
764-
convert_tensor_to_scalar(item) for item in shape.items
765-
]
766-
shape = ListVariable(converted_items)
767-
elif isinstance(shape, TensorVariable):
768-
shape = convert_tensor_to_scalar(shape)
769-
770-
# Create proxy directly to avoid recursion
771-
return wrap_fx_proxy(
772-
tx,
773-
tx.output.create_proxy(
774-
"call_function",
775-
torch.ops.aten.view.default,
776-
*proxy_args_kwargs([tensor, shape], {}),
777-
),
778-
)
779-
780749
@register(torch._foreach_lerp_)
781750
def handle_inplace_foreach_lerp_scalar(
782751
_, tx: "InstructionTranslator", *args, **kwargs

0 commit comments

Comments
 (0)