import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torchvision import models, transforms

import gradio as gr

# Load the deployment config (image size etc.) and the ordered class names.
# Use context managers so the file handles are closed deterministically.
with open("config/prod.yaml") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

with open("class_names.txt") as f:
    class_names = [line.strip() for line in f]
def build_model(num_classes: int) -> nn.Module:
    """Build an EfficientNet-B2 with a resized classification head.

    Starts from ImageNet-pretrained weights and swaps the final Linear
    layer so the model emits ``num_classes`` logits.

    Args:
        num_classes: Number of output classes for the classifier head.

    Returns:
        The modified EfficientNet-B2 model (not yet in eval mode).
    """
    model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model
26- # 2. Load class names
27- # Load class names from file
28- with open ("class_names.txt" ) as f :
29- class_names = [line .strip () for line in f ]
3022
31- # 3. Build and load the model
32- num_classes = len (class_names )
33- model = build_model (num_classes )
34-
35- # If you see _orig_mod keys, strip the prefix! (Due to possibilty of saving compiled version of model during training)
36- ckpt = torch .load ("output/model.pth" , map_location = 'cpu' )
37- new_state_dict = OrderedDict ()
38- for k , v in ckpt .items ():
39- if k .startswith ('_orig_mod.' ):
40- new_state_dict [k [len ('_orig_mod.' ):]] = v
41- else :
42- new_state_dict [k ] = v
23+ def load_model (path : str , num_classes : int ) -> nn .Module :
24+ model = build_model (num_classes )
25+ state = torch .load (path , map_location = "cpu" )
26+ state = {k .replace ("_orig_mod." , "" ): v for k , v in state .items ()}
27+ model .load_state_dict (state )
28+ model .eval ()
29+ return model
4330
# Instantiate the classifier once at import time so every request reuses it.
num_classes = len(class_names)
model = load_model("output/model.pth", num_classes)
# Preprocessing pipeline — must mirror the evaluation transforms used in training.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img-size"]),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
5540
def predict(image):
    """Classify a PIL image and return {class_name: probability}, sorted descending."""
    batch = preprocess(image).unsqueeze(0)  # add batch dimension: [1, C, H, W]
    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits, dim=1)[0]
    order = torch.argsort(probs, descending=True)
    return {class_names[i]: float(probs[i]) for i in order}
# Example images for the UI (optional — only used when an examples/ dir exists).
example_dir = "examples"
if os.path.isdir(example_dir):
    # sorted() makes the example order deterministic; os.listdir order is
    # filesystem-dependent and would shuffle the gallery between deploys.
    examples = [
        [os.path.join(example_dir, f)]
        for f in sorted(os.listdir(example_dir))
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
else:
    examples = None
6559
# Build and launch the Gradio app.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5, label="Top Classes"),
    title="Food101 Classifier",
    examples=examples,
)
demo.launch()