Skip to content

Commit 33046eb

Browse files
committed
fix: Moved BERT module setup to custom_models
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 2e1764a commit 33046eb

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

tests/modules/custom_models.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,29 @@ def forward(self, x):
8585
x = self.conv1(x)
8686
return x
8787

88-
88+
89+
def BertModule():
90+
model_name = "bert-base-uncased"
91+
enc = BertTokenizer.from_pretrained(model_name)
92+
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
93+
tokenized_text = enc.tokenize(text)
94+
masked_index = 8
95+
tokenized_text[masked_index] = "[MASK]"
96+
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
97+
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
98+
tokens_tensor = torch.tensor([indexed_tokens])
99+
segments_tensors = torch.tensor([segments_ids])
100+
config = BertConfig(
101+
vocab_size_or_config_json_file=32000,
102+
hidden_size=768,
103+
num_hidden_layers=12,
104+
num_attention_heads=12,
105+
intermediate_size=3072,
106+
torchscript=True,
107+
)
108+
model = BertModel(config)
109+
model.eval()
110+
model = BertModel.from_pretrained(model_name, torchscript=True)
111+
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
112+
return traced_model
113+

tests/modules/hub.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
"path": "script"
9898
},
9999
"bert-base-uncased": {
100-
"model": "bert-base-uncased",
100+
"model": cm.BertModule(),
101101
"path": "trace"
102102
}
103103
}
@@ -109,34 +109,7 @@ def get(n, m, manifest):
109109
script_filename = n + '_scripted.jit.pt'
110110
x = torch.ones((1, 3, 300, 300)).cuda()
111111
if n == "bert-base-uncased":
112-
# Prepare input for BERT case
113-
def prepare_bert_input():
114-
enc = BertTokenizer.from_pretrained("bert-base-uncased")
115-
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
116-
tokenized_text = enc.tokenize(text)
117-
masked_index = 8
118-
tokenized_text[masked_index] = "[MASK]"
119-
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
120-
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
121-
tokens_tensor = torch.tensor([indexed_tokens])
122-
segments_tensors = torch.tensor([segments_ids])
123-
return [tokens_tensor, segments_tensors]
124-
125-
x = prepare_bert_input()
126-
name = m["model"]
127-
128-
config = BertConfig(
129-
vocab_size_or_config_json_file=32000,
130-
hidden_size=768,
131-
num_hidden_layers=12,
132-
num_attention_heads=12,
133-
intermediate_size=3072,
134-
torchscript=True,
135-
)
136-
m["model"] = BertModel(config)
137-
m["model"].eval()
138-
m["model"] = BertModel.from_pretrained(name, torchscript=True)
139-
traced_model = torch.jit.trace(m["model"], x)
112+
traced_model = m["model"]
140113
torch.jit.save(traced_model, traced_filename)
141114
manifest.update({n : [traced_filename]})
142115
else:
@@ -182,6 +155,8 @@ def main():
182155
# Check if Manifest file exists or is empty
183156
if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
184157
manifest = {"version": torch_version}
158+
159+
# Creating an empty manifest file for overwriting post setup
185160
os.system('touch {}'.format(MANIFEST_FILE))
186161
else:
187162
manifest_exists = True
@@ -191,13 +166,14 @@ def main():
191166
manifest = json.load(f)
192167
if manifest['version'] == torch_version:
193168
version_matches = True
194-
# Overwrite the manifest version as current torch version
195-
manifest['version'] = torch_version
196169
else:
197170
print("Torch version: {} mismatches \
198171
with manifest's version: {}. Re-downloading \
199172
all models".format(torch_version, manifest['version']))
200173

174+
# Overwrite the manifest version as current torch version
175+
manifest['version'] = torch_version
176+
201177
download_models(version_matches, manifest)
202178

203179
# Write updated manifest file to disk

0 commit comments

Comments
 (0)