1515
1616
1717class MV2Model (EagerModelBase ):
18- def __init__ (self ):
18+ def __init__ (self , useRealInput = True ):
19+ self .useRealInput = useRealInput
1920 pass
2021
2122 def get_eager_model (self ) -> torch .nn .Module :
@@ -26,6 +27,27 @@ def get_eager_model(self) -> torch.nn.Module:
2627
2728 def get_example_inputs (self ):
2829 tensor_size = (1 , 3 , 224 , 224 )
30+ input_batch = (torch .randn (tensor_size ),)
31+ if self .useRealInput :
32+ logging .info ("Loaded real input image dog.jpg" )
33+ import urllib
34+ url , filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg" , "dog.jpg" )
35+ try :
36+ urllib .URLopener ().retrieve (url , filename )
37+ except :
38+ urllib .request .urlretrieve (url , filename )
39+ from PIL import Image
40+ from torchvision import transforms
41+ input_image = Image .open (filename )
42+ preprocess = transforms .Compose ([
43+ transforms .Resize (256 ),
44+ transforms .CenterCrop (224 ),
45+ transforms .ToTensor (),
46+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
47+ ])
48+ input_tensor = preprocess (input_image )
49+ input_batch = input_tensor .unsqueeze (0 )
50+ input_batch = (input_batch ,)
2951 return (torch .randn (tensor_size ),)
3052
3153
0 commit comments