Skip to content

Commit 503f320

Browse files
committed
ready for review
1 parent 35d5861 commit 503f320

File tree

3 files changed

+12
-32
lines changed

3 files changed

+12
-32
lines changed

examples/apps/flux_demo.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,6 @@ def compile_model(
6262
torch_dtype=torch.float16,
6363
).to(torch.float16)
6464

65-
# pipe.transformer = FluxTransformer2DModel(
66-
# num_layers=28, num_single_layers=12, guidance_embeds=True
67-
# ).to(torch.float16)
68-
6965
if args.low_vram_mode:
7066
pipe.enable_model_cpu_offload()
7167
else:

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def compile(
694694
# Move the weights in the state_dict to CPU
695695
if offload_module_to_cpu:
696696
deallocate_module(gm, delete_module=False)
697-
# deallocate_module(exported_program.module(), delete_module=False)
697+
deallocate_module(exported_program.module(), delete_module=False)
698698
logger.info(
699699
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
700700
)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -591,13 +591,11 @@ def _save_weight_mapping(self) -> None:
591591
torch.cuda.empty_cache()
592592

593593
@needs_refit # type: ignore[misc]
594-
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
594+
def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
595+
serialized_engine = engine.serialize()
595596
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
596597
# if not self.compilation_settings.strip_engine_weights:
597598
# # set EXCLUDE_WEIGHTS flag to strip weights
598-
# runtime = trt.Runtime(TRT_LOGGER)
599-
# engine = runtime.deserialize_cuda_engine(serialized_engine)
600-
601599
# serialization_config = engine.create_serialization_config()
602600
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
603601
# serialized_engine = engine.serialize_with_config(
@@ -731,10 +729,6 @@ def run(
731729
if interpreter_result is not None: # hit the cache
732730
return interpreter_result # type: ignore[no-any-return]
733731

734-
import psutil
735-
736-
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
737-
# breakpoint()
738732
self._construct_trt_network_def()
739733

740734
if not self.compilation_settings.immutable_weights:
@@ -753,14 +747,11 @@ def run(
753747
self._create_timing_cache(
754748
builder_config, self.compilation_settings.timing_cache_path
755749
)
756-
import psutil
757-
758-
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
759-
# breakpoint()
760750

761751
cuda_engine = self.builder.build_engine_with_config(
762752
self.ctx.net, builder_config
763753
)
754+
assert cuda_engine
764755

765756
_LOGGER.info(
766757
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
@@ -772,17 +763,13 @@ def run(
772763
)
773764

774765
# Engine caching only for refittable engines
775-
# if (
776-
# not self.compilation_settings.immutable_weights
777-
# and self.compilation_settings.cache_built_engines
778-
# and self.engine_cache is not None
779-
# ):
780-
# self._insert_engine_to_cache(hash_val, serialized_engine)
781-
782-
print("After build_engine_with_config")
783-
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
784-
# breakpoint()
785-
assert cuda_engine
766+
if (
767+
not self.compilation_settings.immutable_weights
768+
and self.compilation_settings.cache_built_engines
769+
and self.engine_cache is not None
770+
):
771+
self._insert_engine_to_cache(hash_val, cuda_engine)
772+
786773
if self.compilation_settings.use_python_runtime:
787774
return TRTInterpreterResult(
788775
cuda_engine,
@@ -792,16 +779,13 @@ def run(
792779
self.ctx.requires_output_allocator,
793780
)
794781
else:
795-
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
796-
# breakpoint()
797782
serialized_engine = cuda_engine.serialize()
798783
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
799784

800785
with io.BytesIO() as engine_bytes:
801786
engine_bytes.write(serialized_engine)
802787
engine_str = engine_bytes.getvalue()
803-
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
804-
# breakpoint()
788+
805789
return TRTInterpreterResult(
806790
engine_str,
807791
self._input_names,

0 commit comments

Comments
 (0)