 import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 from transformers import BertModel
-from transformers.utils.fx import symbolic_trace as transformers_trace

 from packaging.version import Version

@@ -196,16 +195,18 @@ def test_resnet18_half(ir):


 @unittest.skipIf(
-    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
-    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
+    torch.cuda.get_device_capability() < (8, 9),
+    "FP8 quantization requires compute capability 8.9 or later",
 )
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
-    reason="ModelOpt is necessary to run this test",
+    "ModelOpt is required to run this test",
 )
 @pytest.mark.unit
 def test_base_fp8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export

     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -219,9 +220,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x

-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -236,7 +234,7 @@ def calibrate_loop(model):

     with torch.no_grad():
         with export_torch_mode():
-            exp_program = torch.export.export(model, (input_tensor,))
+            exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
                 inputs=[input_tensor],
@@ -247,7 +245,7 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)


 @unittest.skipIf(
@@ -258,7 +256,9 @@ def calibrate_loop(model):
 )
 @pytest.mark.unit
 def test_base_int8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export

     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -272,9 +272,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x

-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -289,8 +286,6 @@ def calibrate_loop(model):

     with torch.no_grad():
         with export_torch_mode():
-            from torch.export._trace import _export
-
             exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
@@ -302,4 +297,4 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)