Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions py/torch_tensorrt/dynamo/_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,20 @@ def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheH
Returns:
Optional[Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
"""
packed_cache_info = self.load(hash, *args, **kwargs)
if self.exist(hash):
packed_cache_info = self.load(hash, *args, **kwargs)
if packed_cache_info:
return BaseEngineCache.unpack(packed_cache_info)
return None

if packed_cache_info:
return BaseEngineCache.unpack(packed_cache_info)
else:
return None
@abstractmethod
def exist(self, hash: str) -> bool:
"""Check if a cache entry exists for the given hash.

Args:
hash (str): hash value of the GraphModule
"""
pass

@abstractmethod
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -306,6 +314,14 @@ def LRU() -> None:
else:
LRU()

def exist(self, hash: str) -> bool:
directory = os.path.join(self.engine_cache_dir, hash)
if os.path.exists(directory):
blob_path = os.path.join(directory, "blob.bin")
if os.path.exists(blob_path):
return True
return False

def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
blob_size = len(blob)
if blob_size > self.total_engine_cache_size:
Expand Down
26 changes: 14 additions & 12 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,20 @@ def interpret_module_to_result(
and engine_cache is not None
):
hash_val = engine_cache.get_hash(module, inputs, settings)
engine_cache.insert(
hash_val,
(
serialized_engine,
interpreter_result.input_names,
interpreter_result.output_names,
inputs,
settings,
interpreter_result.weight_name_map,
interpreter_result.requires_output_allocator,
),
)
# only insert if the cache entry does not exist
if not engine_cache.exist(hash_val):
engine_cache.insert(
hash_val,
(
serialized_engine,
interpreter_result.input_names,
interpreter_result.output_names,
inputs,
settings,
interpreter_result.weight_name_map,
interpreter_result.requires_output_allocator,
),
)

serialized_interpreter_result = SerializedInterpreterResult(
serialized_engine=serialized_engine,
Expand Down
3 changes: 3 additions & 0 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(

self.hashes = {}

def exist(self, hash: str) -> bool:
return hash in self.hashes

def save(
self,
hash: str,
Expand Down
Loading