Skip to content

Commit a7c75c8

Browse files
committed
fixed unable to read state_dict issue
1 parent 9cbaf31 commit a7c75c8

File tree

5 files changed

+128
-8
lines changed

5 files changed

+128
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Supports local development, SageMaker training, flexible dataset prep, Weights &
2929
│ └── train.py # Main training script
3030
├── .env.example # Example for API keys/secrets
3131
├── requirements.txt # Pip dependencies
32-
├── gradio_app.py # Gradio tnterface
32+
├── gradio_app.py # Gradio interface
3333
├── README.md
3434
└── .gitignore
3535
```

class_names.txt

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
apple_pie
2+
baby_back_ribs
3+
baklava
4+
beef_carpaccio
5+
beef_tartare
6+
beet_salad
7+
beignets
8+
bibimbap
9+
bread_pudding
10+
breakfast_burrito
11+
bruschetta
12+
caesar_salad
13+
cannoli
14+
caprese_salad
15+
carrot_cake
16+
ceviche
17+
cheese_plate
18+
cheesecake
19+
chicken_curry
20+
chicken_quesadilla
21+
chicken_wings
22+
chocolate_cake
23+
chocolate_mousse
24+
churros
25+
clam_chowder
26+
club_sandwich
27+
crab_cakes
28+
creme_brulee
29+
croque_madame
30+
cup_cakes
31+
deviled_eggs
32+
donuts
33+
dumplings
34+
edamame
35+
eggs_benedict
36+
escargots
37+
falafel
38+
filet_mignon
39+
fish_and_chips
40+
foie_gras
41+
french_fries
42+
french_onion_soup
43+
french_toast
44+
fried_calamari
45+
fried_rice
46+
frozen_yogurt
47+
garlic_bread
48+
gnocchi
49+
greek_salad
50+
grilled_cheese_sandwich
51+
grilled_salmon
52+
guacamole
53+
gyoza
54+
hamburger
55+
hot_and_sour_soup
56+
hot_dog
57+
huevos_rancheros
58+
hummus
59+
ice_cream
60+
lasagna
61+
lobster_bisque
62+
lobster_roll_sandwich
63+
macaroni_and_cheese
64+
macarons
65+
miso_soup
66+
mussels
67+
nachos
68+
omelette
69+
onion_rings
70+
oysters
71+
pad_thai
72+
paella
73+
pancakes
74+
panna_cotta
75+
peking_duck
76+
pho
77+
pizza
78+
pork_chop
79+
poutine
80+
prime_rib
81+
pulled_pork_sandwich
82+
ramen
83+
ravioli
84+
red_velvet_cake
85+
risotto
86+
samosa
87+
sashimi
88+
scallops
89+
seaweed_salad
90+
shrimp_and_grits
91+
spaghetti_bolognese
92+
spaghetti_carbonara
93+
spring_rolls
94+
steak
95+
strawberry_shortcake
96+
sushi
97+
tacos
98+
takoyaki
99+
tiramisu
100+
tuna_tartare
101+
waffles

create_dataset_names.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import os
2+
3+
train_dir = "data/sample/train"
4+
class_names = sorted(os.listdir(train_dir))
5+
6+
with open("class_names.txt", "w") as f:
7+
for name in class_names:
8+
f.write(name + "\n")

gradio_app.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import os, torch, yaml
1+
from collections import OrderedDict
2+
import torch, yaml
23
import torch.nn as nn
34
import torch.nn.functional as F
45
from torchvision import models, transforms
@@ -23,20 +24,30 @@ def build_model(num_classes):
2324
return model
2425

2526
# 2. Load class names
26-
# Assuming same folder structure as the default flags for train.py's train-dir
27-
train_dir = "data/sample/train"
28-
class_names = sorted(os.listdir(train_dir))
27+
# Load class names from file
28+
with open("class_names.txt") as f:
29+
class_names = [line.strip() for line in f]
2930

3031
# 3. Build and load the model
3132
num_classes = len(class_names)
3233
model = build_model(num_classes)
33-
model.load_state_dict(torch.load("output/model.pth", map_location="cpu"))
34+
35+
# If you see _orig_mod keys, strip the prefix! (Due to the possibility of saving a compiled version of the model during training)
36+
ckpt = torch.load("output/model.pth", map_location='cpu')
37+
new_state_dict = OrderedDict()
38+
for k, v in ckpt.items():
39+
if k.startswith('_orig_mod.'):
40+
new_state_dict[k[len('_orig_mod.'):]] = v
41+
else:
42+
new_state_dict[k] = v
43+
44+
model.load_state_dict(new_state_dict)
3445
model.eval()
3546

3647
# 4. Preprocessing: same as test transforms in train.py
3748
preprocess = transforms.Compose([
3849
transforms.Resize(256),
39-
transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img_size"]),
50+
transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img-size"]),
4051
transforms.ToTensor(),
4152
transforms.Normalize([0.485,0.456,0.406],
4253
[0.229,0.224,0.225])

scripts/download_full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def main():
1616
shutil.rmtree(out)
1717
out.mkdir(parents=True)
1818

19-
# download MNIST into .cache
19+
# download dataset into .cache
2020
cache = Path(".cache")
2121
ds_train = datasets.Food101(cache, split="train", download=True)
2222
ds_test = datasets.Food101(cache, split="test", download=True)

0 commit comments

Comments
 (0)