1- # Copyright 2021 MONAI Consortium
1+ # Copyright 2021-2025 MONAI Consortium
22# Licensed under the Apache License, Version 2.0 (the "License");
33# you may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
@@ -24,11 +24,12 @@ class TorchScriptModel(Model):
2424 """Represents TorchScript model.
2525
2626 TorchScript serialization format (TorchScript model file) is created by torch.jit.save() method and
27- the serialized model (which usually has .pt or .pth extension) is a ZIP archive containing many files .
27+ the serialized model (which usually has .pt or .pth extension) is a ZIP archive.
2828 (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md)
2929
30- We consider that the model is a torchscript model if its unzipped archive contains files named 'data.pkl' and
31- 'constants.pkl', and folders named 'code' and 'data'.
30+ We identify a file as a TorchScript model if its unzipped archive contains a 'code/' directory
31+ and a 'data.pkl' file. For tensor constants, it may contain either a 'constants.pkl' file (older format)
32+ or a 'constants/' directory (newer format).
3233
3334 When predictor property is accessed or the object is called (__call__), the model is loaded in `evaluation mode`
3435 from the serialized model file (if it is not loaded yet) and the model is ready to be used.
@@ -85,31 +86,38 @@ def train(self, mode: bool = True) -> "TorchScriptModel":
8586
8687 @classmethod
8788 def accept (cls , path : str ):
88- prefix_code = False
89- prefix_data = False
90- prefix_constants_pkl = False
91- prefix_data_pkl = False
89+ # These are the files and directories we expect to find in a TorchScript zip archive.
90+ # See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/docs/serialization.md
91+ has_code_dir = False
92+ has_constants_dir = False
93+ has_constants_pkl = False
94+ has_data_pkl = False
9295
9396 if not os .path .isfile (path ):
9497 return False , None
9598
9699 try :
97- zip_file = ZipFile (path )
98- for data in zip_file .filelist :
99- file_name = data .filename
100- pivot = file_name .find ("/" )
101- if pivot != - 1 and not prefix_code and file_name [pivot :].startswith ("/code/" ):
102- prefix_code = True
103- if pivot != - 1 and not prefix_data and file_name [pivot :].startswith ("/data/" ):
104- prefix_data = True
105- if pivot != - 1 and not prefix_constants_pkl and file_name [pivot :] == "/constants.pkl" :
106- prefix_constants_pkl = True
107- if pivot != - 1 and not prefix_data_pkl and file_name [pivot :] == "/data.pkl" :
108- prefix_data_pkl = True
109- except BadZipFile :
100+ with ZipFile (path ) as zip_file :
101+ # Top-level directory name in the zip file (e.g., 'model_name/')
102+ top_level_dir = ""
103+ if "/" in zip_file .filelist [0 ].filename :
104+ top_level_dir = zip_file .filelist [0 ].filename .split ("/" , 1 )[0 ] + "/"
105+
106+ filenames = {f .filename for f in zip_file .filelist }
107+
108+ # Check for required files and directories
109+ has_data_pkl = (top_level_dir + "data.pkl" ) in filenames
110+ has_code_dir = any (f .startswith (top_level_dir + "code/" ) for f in filenames )
111+
112+ # Check for either constants.pkl (older format) or constants/ (newer format)
113+ has_constants_pkl = (top_level_dir + "constants.pkl" ) in filenames
114+ has_constants_dir = any (f .startswith (top_level_dir + "constants/" ) for f in filenames )
115+
116+ except (BadZipFile , IndexError ):
110117 return False , None
111118
112- if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl :
119+ # A valid TorchScript model must have code/, data.pkl, and either constants.pkl or constants/
120+ if has_code_dir and has_data_pkl and (has_constants_pkl or has_constants_dir ):
113121 return True , cls .model_type
114122
115123 return False , None
0 commit comments