Skip to content

Commit 4c2340d

Browse files
committed
chore: Applying lint
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 33046eb commit 4c2340d

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

tests/modules/custom_models.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from transformers import BertModel, BertTokenizer, BertConfig
44
import torch.nn.functional as F
55

6+
67
# Sample Pool Model (for testing plugin serialization)
78
class Pool(nn.Module):
89

@@ -98,16 +99,15 @@ def BertModule():
9899
tokens_tensor = torch.tensor([indexed_tokens])
99100
segments_tensors = torch.tensor([segments_ids])
100101
config = BertConfig(
101-
vocab_size_or_config_json_file=32000,
102-
hidden_size=768,
103-
num_hidden_layers=12,
104-
num_attention_heads=12,
105-
intermediate_size=3072,
106-
torchscript=True,
107-
)
102+
vocab_size_or_config_json_file=32000,
103+
hidden_size=768,
104+
num_hidden_layers=12,
105+
num_attention_heads=12,
106+
intermediate_size=3072,
107+
torchscript=True,
108+
)
108109
model = BertModel(config)
109110
model.eval()
110111
model = BertModel.from_pretrained(model_name, torchscript=True)
111112
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
112113
return traced_model
113-

tests/modules/hub.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,25 @@ def get(n, m, manifest):
111111
if n == "bert-base-uncased":
112112
traced_model = m["model"]
113113
torch.jit.save(traced_model, traced_filename)
114-
manifest.update({n : [traced_filename]})
114+
manifest.update({n: [traced_filename]})
115115
else:
116116
m["model"] = m["model"].eval().cuda()
117117
if m["path"] == "both" or m["path"] == "trace":
118118
trace_model = torch.jit.trace(m["model"], [x])
119119
torch.jit.save(trace_model, traced_filename)
120-
manifest.update({n : [traced_filename]})
120+
manifest.update({n: [traced_filename]})
121121
if m["path"] == "both" or m["path"] == "script":
122122
script_model = torch.jit.script(m["model"])
123123
torch.jit.save(script_model, script_filename)
124124
if n in manifest.keys():
125125
files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
126126
files.append(script_filename)
127-
manifest.update({n : files})
127+
manifest.update({n: files})
128128
else:
129129
manifest.update({n: [script_filename]})
130130
return manifest
131131

132+
132133
def download_models(version_matches, manifest):
133134
# Download all models if torch version is different than model version
134135
if not version_matches:
@@ -142,8 +143,8 @@ def download_models(version_matches, manifest):
142143
if (m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename)) or \
143144
(m["path"] == "script" and os.path.exists(scripted_filename)) or \
144145
(m["path"] == "trace" and os.path.exists(traced_filename)):
145-
print("Skipping {} ".format(n))
146-
continue
146+
print("Skipping {} ".format(n))
147+
continue
147148
manifest = get(n, m, manifest)
148149

149150

@@ -184,4 +185,5 @@ def main():
184185
f.write(record)
185186
f.truncate()
186187

187-
main()
188+
189+
main()

0 commit comments

Comments
 (0)