Skip to content

Commit 885f1d9

Browse files
committed
using dog.jpg as real input default
1 parent eccf7cb commit 885f1d9

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

examples/models/mobilenet_v2/model.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616

1717
class MV2Model(EagerModelBase):
18-
def __init__(self):
18+
def __init__(self, useRealInput=True):
19+
self.useRealInput = useRealInput
1920
pass
2021

2122
def get_eager_model(self) -> torch.nn.Module:
@@ -26,6 +27,27 @@ def get_eager_model(self) -> torch.nn.Module:
2627

2728
def get_example_inputs(self):
2829
tensor_size = (1, 3, 224, 224)
30+
input_batch = (torch.randn(tensor_size),)
31+
if self.useRealInput:
32+
logging.info("Loaded real input image dog.jpg")
33+
import urllib
34+
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
35+
try:
36+
urllib.URLopener().retrieve(url, filename)
37+
except:
38+
urllib.request.urlretrieve(url, filename)
39+
from PIL import Image
40+
from torchvision import transforms
41+
input_image = Image.open(filename)
42+
preprocess = transforms.Compose([
43+
transforms.Resize(256),
44+
transforms.CenterCrop(224),
45+
transforms.ToTensor(),
46+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47+
])
48+
input_tensor = preprocess(input_image)
49+
input_batch = input_tensor.unsqueeze(0)
50+
input_batch = (input_batch,)
2951
return (torch.randn(tensor_size),)
3052

3153

0 commit comments

Comments
 (0)