Skip to content

Commit b7beb37

Browse files
authored
Arm backend: update the pretrained flag (#15739)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Tirui Wu <[email protected]>
1 parent a44f68d commit b7beb37

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

examples/models/deit_tiny/model.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77

88
import torch
9-
from torchvision import transforms
109

1110
try:
1211
import timm # type: ignore
@@ -15,8 +14,6 @@
1514
"timm package is required for builtin 'deit_tiny'. Install timm."
1615
) from e
1716

18-
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
19-
2017
from ..model_base import EagerModelBase
2118

2219

@@ -27,16 +24,13 @@ def __init__(self): # type: ignore[override]
2724

2825
def get_eager_model(self) -> torch.nn.Module: # type: ignore[override]
2926
logging.info("Loading timm deit_tiny_patch16_224 model")
30-
model = timm.models.deit.deit_tiny_patch16_224(pretrained=False)
31-
model.eval()
27+
model = timm.models.deit.deit_tiny_patch16_224(pretrained=True)
3228
logging.info("Loaded timm deit_tiny_patch16_224 model")
3329
return model
3430

3531
def get_example_inputs(self): # type: ignore[override]
36-
normalize = transforms.Normalize(
37-
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
38-
)
39-
return (normalize(torch.rand((1, 3, 224, 224))),)
32+
input_shape = (1, 3, 224, 224)
33+
return (torch.randn(input_shape),)
4034

4135

4236
__all__ = ["DeiTTinyModel"]

0 commit comments

Comments
 (0)