Skip to content

Commit c57aa04

Browse files
committed
chore: bug fixes for refit tests, restructure CI tests
1 parent 7cba8f2 commit c57aa04

File tree

4 files changed

+30
-244
lines changed

4 files changed

+30
-244
lines changed

.github/workflows/build-test-linux.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,12 @@ jobs:
172172
cd tests/py
173173
python -m pip install -r requirements.txt
174174
cd dynamo
175-
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
176-
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml test_modelopt_models.py
175+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py
176+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py
177+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py
178+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py
179+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py
180+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py
177181
popd
178182
179183
tests-py-dynamo-serde:
@@ -206,6 +210,7 @@ jobs:
206210
cd dynamo
207211
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
208212
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
213+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py
209214
popd
210215
211216
tests-py-torch-compile-be:

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
logger = logging.getLogger(__name__)
4949

5050

51-
@needs_refit
51+
@needs_refit # type: ignore
5252
def construct_refit_mapping(
5353
module: torch.fx.GraphModule,
5454
inputs: Sequence[Input],
@@ -110,7 +110,7 @@ def construct_refit_mapping(
110110
return weight_map
111111

112112

113-
@needs_refit
113+
@needs_refit # type: ignore
114114
def construct_refit_mapping_from_weight_name_map(
115115
weight_name_map: dict[Any, Any],
116116
state_dict: dict[Any, Any],
@@ -141,7 +141,7 @@ def construct_refit_mapping_from_weight_name_map(
141141
return engine_weight_map
142142

143143

144-
@needs_refit
144+
@needs_refit # type: ignore
145145
def _refit_single_trt_engine_with_gm(
146146
new_gm: torch.fx.GraphModule,
147147
old_engine: trt.ICudaEngine,
@@ -153,12 +153,12 @@ def _refit_single_trt_engine_with_gm(
153153
Refit a TensorRT Engine in place
154154
"""
155155

156-
with unset_fake_temporarily():
157-
refitted = set()
158-
torch_device = get_model_device(new_gm)
159-
refitter = trt.Refitter(old_engine, TRT_LOGGER)
160-
weight_list = refitter.get_all_weights()
156+
refitted = set()
157+
torch_device = get_model_device(new_gm)
158+
refitter = trt.Refitter(old_engine, TRT_LOGGER)
159+
weight_list = refitter.get_all_weights()
161160

161+
with unset_fake_temporarily():
162162
if weight_name_map:
163163
# Get the refitting mapping
164164
trt_wt_location = (
@@ -185,41 +185,21 @@ def _refit_single_trt_engine_with_gm(
185185
trt_dtype,
186186
)
187187

188-
constant_mapping: dict[str, Any] = weight_name_map.pop(
189-
"constant_mapping", {}
190-
) # type: ignore
191-
mapping = construct_refit_mapping_from_weight_name_map(
192-
weight_name_map, new_gm.state_dict()
193-
)
194-
constant_mapping_with_type = {}
195-
196-
for constant_name, val in constant_mapping.items():
197-
np_weight_type = val.dtype
198-
val_tensor = torch.from_numpy(val).cuda()
199-
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
200-
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
201-
constant_mapping_with_type[constant_name] = (
202-
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
203-
trt_dtype,
204-
)
188+
mapping.update(constant_mapping_with_type)
205189

206-
mapping.update(constant_mapping_with_type)
207-
208-
for layer_name in weight_list:
209-
if layer_name not in mapping:
210-
logger.warning(f"{layer_name} is not found in weight mapping.")
211-
continue
212-
# Use Numpy to create weights
213-
weight, weight_dtype = mapping[layer_name]
214-
trt_wt_tensor = trt.Weights(
215-
weight_dtype, weight.data_ptr(), torch.numel(weight)
216-
)
217-
refitter.set_named_weights(
218-
layer_name, trt_wt_tensor, trt_wt_location
219-
)
220-
assert (
221-
len(refitter.get_missing_weights()) == 0
222-
), "Fast refitting failed due to incomplete mapping"
190+
for layer_name in weight_list:
191+
if layer_name not in mapping:
192+
logger.warning(f"{layer_name} is not found in weight mapping.")
193+
continue
194+
# Use Numpy to create weights
195+
weight, weight_dtype = mapping[layer_name]
196+
trt_wt_tensor = trt.Weights(
197+
weight_dtype, weight.data_ptr(), torch.numel(weight)
198+
)
199+
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
200+
assert (
201+
len(refitter.get_missing_weights()) == 0
202+
), "Fast refitting failed due to incomplete mapping"
223203

224204
else:
225205
mapping = construct_refit_mapping(new_gm, input_list, settings)
@@ -241,7 +221,7 @@ def _refit_single_trt_engine_with_gm(
241221
raise AssertionError("Refitting failed.")
242222

243223

244-
@needs_refit
224+
@needs_refit # type: ignore
245225
def refit_module_weights(
246226
compiled_module: torch.fx.GraphModule | ExportedProgram,
247227
new_weight_module: ExportedProgram,

tests/py/dynamo/models/test_models_export.py

Lines changed: 0 additions & 199 deletions
This file was deleted.

0 commit comments

Comments
 (0)