@@ -33,86 +33,6 @@ def f(t, _n):
3333 t = torch .tensor ([2 , 4 ], dtype = torch .int32 )
3434 f (t , 8 )
3535
36- def test_view_with_tensor_shape_params (self ):
37- # Test for issue #156720: aten.view.default with tensor shape parameters
38- class TestModel (torch .nn .Module ):
39- def forward (self , x , shape_params ):
40- return torch .ops .aten .view .default (x , shape_params )
41-
42- x = torch .randn (24 )
43- shape_params = [
44- torch .tensor (2 , dtype = torch .int32 ),
45- torch .tensor (3 , dtype = torch .int32 ),
46- torch .tensor (4 , dtype = torch .int32 ),
47- ]
48-
49- model = TestModel ()
50- expected = model (x , shape_params )
51-
52- compiled_model = torch .compile (model , backend = "eager" )
53- result = compiled_model (x , shape_params )
54-
55- torch .testing .assert_close (result , expected )
56-
57- def test_tensor_view_with_tensor_shape_params (self ):
58- # Test tensor.view() method with tensor shape parameters (list version)
59- class TestModel (torch .nn .Module ):
60- def forward (self , x , shape_params ):
61- return x .view (shape_params )
62-
63- x = torch .randn (24 )
64- shape_params = (
65- torch .tensor (2 , dtype = torch .int32 ),
66- torch .tensor (3 , dtype = torch .int32 ),
67- torch .tensor (4 , dtype = torch .int32 ),
68- )
69-
70- model = TestModel ()
71- expected = model (x , shape_params )
72-
73- compiled_model = torch .compile (model , backend = "eager" )
74- result = compiled_model (x , shape_params )
75-
76- torch .testing .assert_close (result , expected )
77-
78- def test_tensor_view_with_tensor_args (self ):
79- # Test tensor.view() method with individual tensor arguments
80- class TestModel (torch .nn .Module ):
81- def forward (self , x , dim1 , dim2 , dim3 ):
82- return x .view (dim1 , dim2 , dim3 )
83-
84- x = torch .randn (24 )
85- dim1 = torch .tensor (2 , dtype = torch .int32 )
86- dim2 = torch .tensor (3 , dtype = torch .int32 )
87- dim3 = torch .tensor (4 , dtype = torch .int32 )
88-
89- model = TestModel ()
90- expected = model (x , dim1 , dim2 , dim3 )
91-
92- compiled_model = torch .compile (model , backend = "eager" )
93- result = compiled_model (x , dim1 , dim2 , dim3 )
94-
95- torch .testing .assert_close (result , expected )
96-
97- def test_torch_reshape_with_tensor_shape_params (self ):
98- # Test torch.reshape() function with tensor shape parameters
99- def test_fn (x , shape_params ):
100- return torch .reshape (x , shape_params )
101-
102- x = torch .randn (24 )
103- shape_params = [
104- torch .tensor (2 , dtype = torch .int32 ),
105- torch .tensor (3 , dtype = torch .int32 ),
106- torch .tensor (4 , dtype = torch .int32 ),
107- ]
108-
109- expected = test_fn (x , shape_params )
110-
111- compiled_fn = torch .compile (test_fn , backend = "eager" )
112- result = compiled_fn (x , shape_params )
113-
114- torch .testing .assert_close (result , expected )
115-
11636
11737if __name__ == "__main__" :
11838 from torch ._dynamo .test_case import run_tests
0 commit comments