diff --git a/monai/deploy/core/models/torch_model.py b/monai/deploy/core/models/torch_model.py index 81a3f9ed..6352734e 100644 --- a/monai/deploy/core/models/torch_model.py +++ b/monai/deploy/core/models/torch_model.py @@ -1,4 +1,4 @@ -# Copyright 2021 MONAI Consortium +# Copyright 2021-2025 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,12 @@ class TorchScriptModel(Model): """Represents TorchScript model. TorchScript serialization format (TorchScript model file) is created by torch.jit.save() method and - the serialized model (which usually has .pt or .pth extension) is a ZIP archive containing many files. + the serialized model (which usually has .pt or .pth extension) is a ZIP archive. (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md) - We consider that the model is a torchscript model if its unzipped archive contains files named 'data.pkl' and - 'constants.pkl', and folders named 'code' and 'data'. + We identify a file as a TorchScript model if its unzipped archive contains a 'code/' directory + and a 'data.pkl' file. For tensor constants, it may contain either a 'constants.pkl' file (older format) + or a 'constants/' directory (newer format). When predictor property is accessed or the object is called (__call__), the model is loaded in `evaluation mode` from the serialized model file (if it is not loaded yet) and the model is ready to be used. @@ -85,31 +86,38 @@ def train(self, mode: bool = True) -> "TorchScriptModel": @classmethod def accept(cls, path: str): - prefix_code = False - prefix_data = False - prefix_constants_pkl = False - prefix_data_pkl = False + # These are the files and directories we expect to find in a TorchScript zip archive. + # See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/docs/serialization.md + has_code_dir = False + has_constants_dir = False + has_constants_pkl = False + has_data_pkl = False if not os.path.isfile(path): return False, None try: - zip_file = ZipFile(path) - for data in zip_file.filelist: - file_name = data.filename - pivot = file_name.find("/") - if pivot != -1 and not prefix_code and file_name[pivot:].startswith("/code/"): - prefix_code = True - if pivot != -1 and not prefix_data and file_name[pivot:].startswith("/data/"): - prefix_data = True - if pivot != -1 and not prefix_constants_pkl and file_name[pivot:] == "/constants.pkl": - prefix_constants_pkl = True - if pivot != -1 and not prefix_data_pkl and file_name[pivot:] == "/data.pkl": - prefix_data_pkl = True - except BadZipFile: + with ZipFile(path) as zip_file: + # Top-level directory name in the zip file (e.g., 'model_name/') + top_level_dir = "" + if "/" in zip_file.filelist[0].filename: + top_level_dir = zip_file.filelist[0].filename.split("/", 1)[0] + "/" + + filenames = {f.filename for f in zip_file.filelist} + + # Check for required files and directories + has_data_pkl = (top_level_dir + "data.pkl") in filenames + has_code_dir = any(f.startswith(top_level_dir + "code/") for f in filenames) + + # Check for either constants.pkl (older format) or constants/ (newer format) + has_constants_pkl = (top_level_dir + "constants.pkl") in filenames + has_constants_dir = any(f.startswith(top_level_dir + "constants/") for f in filenames) + + except (BadZipFile, IndexError): return False, None - if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl: + # A valid TorchScript model must have code/, data.pkl, and either constants.pkl or constants/ + if has_code_dir and has_data_pkl and (has_constants_pkl or has_constants_dir): return True, cls.model_type return False, None