@@ -111,24 +111,25 @@ def get(n, m, manifest):
111
111
if n == "bert-base-uncased" :
112
112
traced_model = m ["model" ]
113
113
torch .jit .save (traced_model , traced_filename )
114
- manifest .update ({n : [traced_filename ]})
114
+ manifest .update ({n : [traced_filename ]})
115
115
else :
116
116
m ["model" ] = m ["model" ].eval ().cuda ()
117
117
if m ["path" ] == "both" or m ["path" ] == "trace" :
118
118
trace_model = torch .jit .trace (m ["model" ], [x ])
119
119
torch .jit .save (trace_model , traced_filename )
120
- manifest .update ({n : [traced_filename ]})
120
+ manifest .update ({n : [traced_filename ]})
121
121
if m ["path" ] == "both" or m ["path" ] == "script" :
122
122
script_model = torch .jit .script (m ["model" ])
123
123
torch .jit .save (script_model , script_filename )
124
124
if n in manifest .keys ():
125
125
files = list (manifest [n ]) if type (manifest [n ]) != list else manifest [n ]
126
126
files .append (script_filename )
127
- manifest .update ({n : files })
127
+ manifest .update ({n : files })
128
128
else :
129
129
manifest .update ({n : [script_filename ]})
130
130
return manifest
131
131
132
+
132
133
def download_models (version_matches , manifest ):
133
134
# Download all models if torch version is different than model version
134
135
if not version_matches :
@@ -142,8 +143,8 @@ def download_models(version_matches, manifest):
142
143
if (m ["path" ] == "both" and os .path .exists (scripted_filename ) and os .path .exists (traced_filename )) or \
143
144
(m ["path" ] == "script" and os .path .exists (scripted_filename )) or \
144
145
(m ["path" ] == "trace" and os .path .exists (traced_filename )):
145
- print ("Skipping {} " .format (n ))
146
- continue
146
+ print ("Skipping {} " .format (n ))
147
+ continue
147
148
manifest = get (n , m , manifest )
148
149
149
150
@@ -184,4 +185,5 @@ def main():
184
185
f .write (record )
185
186
f .truncate ()
186
187
187
- main ()
188
+
189
+ main ()
0 commit comments