97
97
"path" : "script"
98
98
},
99
99
"bert-base-uncased" : {
100
- "model" : "bert-base-uncased" ,
100
+ "model" : cm . BertModule () ,
101
101
"path" : "trace"
102
102
}
103
103
}
@@ -109,34 +109,7 @@ def get(n, m, manifest):
109
109
script_filename = n + '_scripted.jit.pt'
110
110
x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
111
111
if n == "bert-base-uncased" :
112
- # Prepare input for BERT case
113
- def prepare_bert_input ():
114
- enc = BertTokenizer .from_pretrained ("bert-base-uncased" )
115
- text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
116
- tokenized_text = enc .tokenize (text )
117
- masked_index = 8
118
- tokenized_text [masked_index ] = "[MASK]"
119
- indexed_tokens = enc .convert_tokens_to_ids (tokenized_text )
120
- segments_ids = [0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
121
- tokens_tensor = torch .tensor ([indexed_tokens ])
122
- segments_tensors = torch .tensor ([segments_ids ])
123
- return [tokens_tensor , segments_tensors ]
124
-
125
- x = prepare_bert_input ()
126
- name = m ["model" ]
127
-
128
- config = BertConfig (
129
- vocab_size_or_config_json_file = 32000 ,
130
- hidden_size = 768 ,
131
- num_hidden_layers = 12 ,
132
- num_attention_heads = 12 ,
133
- intermediate_size = 3072 ,
134
- torchscript = True ,
135
- )
136
- m ["model" ] = BertModel (config )
137
- m ["model" ].eval ()
138
- m ["model" ] = BertModel .from_pretrained (name , torchscript = True )
139
- traced_model = torch .jit .trace (m ["model" ], x )
112
+ traced_model = m ["model" ]
140
113
torch .jit .save (traced_model , traced_filename )
141
114
manifest .update ({n : [traced_filename ]})
142
115
else :
@@ -182,6 +155,8 @@ def main():
182
155
# Check if Manifest file exists or is empty
183
156
if not os .path .exists (MANIFEST_FILE ) or os .stat (MANIFEST_FILE ).st_size == 0 :
184
157
manifest = {"version" : torch_version }
158
+
159
+ # Creating an empty manifest file for overwriting post setup
185
160
os .system ('touch {}' .format (MANIFEST_FILE ))
186
161
else :
187
162
manifest_exists = True
@@ -191,13 +166,14 @@ def main():
191
166
manifest = json .load (f )
192
167
if manifest ['version' ] == torch_version :
193
168
version_matches = True
194
- # Overwrite the manifest version as current torch version
195
- manifest ['version' ] = torch_version
196
169
else :
197
170
print ("Torch version: {} mismatches \
198
171
with manifest's version: {}. Re-downloading \
199
172
all models" .format (torch_version , manifest ['version' ]))
200
173
174
+ # Overwrite the manifest version as current torch version
175
+ manifest ['version' ] = torch_version
176
+
201
177
download_models (version_matches , manifest )
202
178
203
179
# Write updated manifest file to disk
0 commit comments