1- import os , torch , yaml
1+ from collections import OrderedDict
2+ import torch , yaml
23import torch .nn as nn
34import torch .nn .functional as F
45from torchvision import models , transforms
@@ -23,20 +24,30 @@ def build_model(num_classes):
2324 return model
2425
2526# 2. Load class names
26- # Assuming same folder structure as the default flags for train.py's train-dir
27- train_dir = "data/sample/train"
28- class_names = sorted ( os . listdir ( train_dir ))
27+ # Load class names from file
28+ with open ( "class_names.txt" ) as f :
29+ class_names = [ line . strip () for line in f ]
2930
3031# 3. Build and load the model
3132num_classes = len (class_names )
3233model = build_model (num_classes )
33- model .load_state_dict (torch .load ("output/model.pth" , map_location = "cpu" ))
34+
35+ # If you see _orig_mod keys, strip the prefix! (Due to possibilty of saving compiled version of model during training)
36+ ckpt = torch .load ("output/model.pth" , map_location = 'cpu' )
37+ new_state_dict = OrderedDict ()
38+ for k , v in ckpt .items ():
39+ if k .startswith ('_orig_mod.' ):
40+ new_state_dict [k [len ('_orig_mod.' ):]] = v
41+ else :
42+ new_state_dict [k ] = v
43+
44+ model .load_state_dict (new_state_dict )
3445model .eval ()
3546
3647# 4. Preprocessing: same as test transforms in train.py
3748preprocess = transforms .Compose ([
3849 transforms .Resize (256 ),
39- transforms .CenterCrop (cfg ["estimator" ]["hyperparameters" ]["img_size " ]),
50+ transforms .CenterCrop (cfg ["estimator" ]["hyperparameters" ]["img-size " ]),
4051 transforms .ToTensor (),
4152 transforms .Normalize ([0.485 ,0.456 ,0.406 ],
4253 [0.229 ,0.224 ,0.225 ])
0 commit comments