Skip to content

Commit 271dfa4

Browse files
committed
Pass num_classes to model
1 parent 0c30c74 commit 271dfa4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,17 @@ def get_model_by_name(system_config, run_config) -> torch.nn.Module:
174174
name = run_config["model"]
175175
repo = system_config["models"][name]["torch.hub.load"]["repo"]
176176
model = system_config["models"][name]["torch.hub.load"]["model"]
177+
num_classes = system_config["datasets"][run_config["dataset"]]["num-classes"]
177178

178179
if name == "dummy":
179180
return DummyModel()
180181
elif name == "ann":
181182
return ANN()
182183
else:
183184
try:
184-
return torch.hub.load(repo, model, pretrained=False, trust_repo=True)
185+
return torch.hub.load(
186+
repo, model, pretrained=False, trust_repo=True, num_classes=num_classes
187+
)
185188
except:
186189
raise NotImplementedError(
187190
f"An error occurred while loading {model}" f" from torch.hub."

0 commit comments

Comments
 (0)