diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py
index a6d9a1face..60a8eb94f6 100644
--- a/py/torch_tensorrt/dynamo/_engine_cache.py
+++ b/py/torch_tensorrt/dynamo/_engine_cache.py
@@ -187,12 +187,20 @@ def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheH
         Returns:
             Optional[Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
         """
-        packed_cache_info = self.load(hash, *args, **kwargs)
+        if self.exist(hash):
+            packed_cache_info = self.load(hash, *args, **kwargs)
+            if packed_cache_info:
+                return BaseEngineCache.unpack(packed_cache_info)
+        return None
 
-        if packed_cache_info:
-            return BaseEngineCache.unpack(packed_cache_info)
-        else:
-            return None
+    @abstractmethod
+    def exist(self, hash: str) -> bool:
+        """Check if a cache entry exists for the given hash.
+
+        Args:
+            hash (str): hash value of the GraphModule
+        """
+        pass
 
     @abstractmethod
     def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
@@ -306,6 +314,14 @@ def LRU() -> None:
         else:
             LRU()
 
+    def exist(self, hash: str) -> bool:
+        directory = os.path.join(self.engine_cache_dir, hash)
+        if os.path.exists(directory):
+            blob_path = os.path.join(directory, "blob.bin")
+            if os.path.exists(blob_path):
+                return True
+        return False
+
     def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         blob_size = len(blob)
         if blob_size > self.total_engine_cache_size:
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 76926107a4..c847c6664f 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -101,18 +101,20 @@ def interpret_module_to_result(
         and engine_cache is not None
     ):
         hash_val = engine_cache.get_hash(module, inputs, settings)
-        engine_cache.insert(
-            hash_val,
-            (
-                serialized_engine,
-                interpreter_result.input_names,
-                interpreter_result.output_names,
-                inputs,
-                settings,
-                interpreter_result.weight_name_map,
-                interpreter_result.requires_output_allocator,
-            ),
-        )
+        # only insert if the cache entry does not exist
+        if not engine_cache.exist(hash_val):
+            engine_cache.insert(
+                hash_val,
+                (
+                    serialized_engine,
+                    interpreter_result.input_names,
+                    interpreter_result.output_names,
+                    inputs,
+                    settings,
+                    interpreter_result.weight_name_map,
+                    interpreter_result.requires_output_allocator,
+                ),
+            )
 
     serialized_interpreter_result = SerializedInterpreterResult(
         serialized_engine=serialized_engine,
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 5e310900aa..b4afc5acc2 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -31,6 +31,9 @@ def __init__(
 
         self.hashes = {}
 
+    def exist(self, hash: str) -> bool:
+        return hash in self.hashes
+
     def save(
         self,
         hash: str,
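
With `exist` declared as an abstract method on `BaseEngineCache`, any custom engine cache (like `MyEngineCache` in the test above) must now implement it alongside `save` and `load`; `check()` and `interpret_module_to_result()` call it before `load()`/`insert()`, so it should be cheap. Below is a minimal sketch of such a subclass under stated assumptions: the class name `InMemoryEngineCache` and its dict-backed storage are hypothetical, and the `load` return type (`Optional[bytes]`) is assumed to mirror the existing abstract interface.

```python
from typing import Any, Dict, Optional

from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class InMemoryEngineCache(BaseEngineCache):
    """Hypothetical example: keeps packed cache blobs in a dict keyed by hash."""

    def __init__(self) -> None:
        self._entries: Dict[str, bytes] = {}

    def exist(self, hash: str) -> bool:
        # Called by check() and by interpret_module_to_result() before
        # load()/insert(), so this is just a dict membership test.
        return hash in self._entries

    def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
        self._entries[hash] = blob

    def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
        return self._entries.get(hash)
```

Keeping `exist` lightweight matters because `interpret_module_to_result` now consults it on every compilation to decide whether to skip the `insert`, rather than unconditionally re-serializing and re-saving an engine that is already cached.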