Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,6 @@ def __new__(

if pretrained:
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
download_model(net=model, model_name=f"{model_name}", local_model_store_dir_path=str(cache_dir))
download_model(net=model, model_name=f"{model_name}c", local_model_store_dir_path=str(cache_dir))
Copy link
Member

@sovrasov sovrasov Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Ilya for the contribution. I think here we can not just replace this weights without a set of comprehensive experiments. 100% positive correlation with IN-top1 is not given, there could be sudden outlier which we'd like to avoid.

To tackle that, I'd propose moving this version parameter to init of EfficientNetBackbone and forwarding it from the endpoint classes (EfficientNetMulticlassCls etc). Once done, the model version is configurable from model recipe yaml file. Corner case of b0 and b1 should be handled.

print(f"Download model weight in {cache_dir!s}")
return model
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ class TestOTXEfficientNet:
"efficientnet_b8",
],
)
def test_forward(self, model_name):
model = EfficientNetBackbone(model_name, pretrained=None)
@pytest.mark.parametrize("pretrained", [True, False])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, don't download anything here to avoid increase of UTs duration and reliability. You could mock download_model if the idea is to check that the URL is correct somehow.

def test_forward(self, model_name, pretrained):
model = EfficientNetBackbone(model_name, pretrained=pretrained)
assert model(torch.randn(1, 3, 244, 244))[0].shape[-1] == 8
assert model(torch.randn(1, 3, 244, 244))[0].shape[-2] == 8

Expand Down
Loading