@@ -158,118 +158,6 @@ def forward(self, x):
         torch._dynamo.reset()
 
 
-class TestLowerLinear(TestCase):
-    @unittest.skip(
-        "This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715",
-    )
-    def test_lower_linear(self):
-        class Linear(torch.nn.Module):
-            def forward(self, input, weight, bias):
-                out = torch.ops.aten.linear.default(input, weight, bias)
-                return out
-
-        inputs = [
-            torch.rand((3, 32)).cuda(),
-            torch.rand((64, 32)).cuda(),
-            torch.rand((64,)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(Linear())
-        expected_ops = {torch.ops.aten.linear.default}
-        unexpected_ops = {
-            torch.ops.aten.permute.default,
-            torch.ops.aten.addmm.default,
-        }
-
-        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
-            fx_graph,
-            inputs,
-            expected_ops=expected_ops,
-            unexpected_ops=unexpected_ops,
-            min_block_size=1,
-        )
-
-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-        torch._dynamo.reset()
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
-        )
-        torch_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
-        )
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            msg=f"Linear TRT outputs don't match with the original model.",
-        )
-        torch._dynamo.reset()
-
-    def test_lower_linear_batch(self):
-        class Linear(torch.nn.Module):
-            def forward(self, input, weight, bias):
-                out = torch.ops.aten.linear.default(input, weight, bias)
-                return out
-
-        inputs = [
-            torch.rand((2, 2, 32)).cuda(),
-            torch.rand((64, 32)).cuda(),
-            torch.rand((64,)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(Linear())
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
-        )
-        torch_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
-        )
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            msg=f"Linear TRT outputs don't match with the original model.",
-        )
-        torch._dynamo.reset()
-
-
 class TestLowerViewToReshape(TestCase):
     def test_view_to_reshape(self):
         class ViewToReshape(torch.nn.Module):