Commit 9cbaf31

fixed compiling issue
1 parent b809759 commit 9cbaf31

File tree

4 files changed: +42 -46 lines changed

README.md

Lines changed: 7 additions & 7 deletions
@@ -10,7 +10,7 @@ Supports local development, SageMaker training, flexible dataset prep, Weights &
 - Logging and experiment tracking (Weights & Biases)
 - Model checkpointing and flexible configuration
 - Ready for deployment (Gradio web app)
-- Mixed precision training (with `autocast` and `GradScaler`) for improved speed and memory efficiency on GPU
+- Gradient clipping, OneCycle LR policy, and mixed precision training (with `autocast` and `GradScaler`) for improved stability and GPU memory efficiency

 ## Project Structure
 ```graphql
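The new feature bullet above bundles gradient clipping, the OneCycle LR policy, and mixed precision. A minimal sketch of how these typically compose in one PyTorch training step (illustrative only, not the repo's actual loop; the model, loader, and `max_norm=1.0` are placeholder assumptions):

```python
import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR

def train_one_phase(model: nn.Module, loader, device: torch.device, epochs: int = 1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=4e-3)
    scheduler = OneCycleLR(optimizer, max_lr=4e-3,
                           total_steps=epochs * len(loader),
                           pct_start=0.35, anneal_strategy="cos")
    scaler = GradScaler(enabled=(device.type == "cuda"))  # no-op on CPU
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type=device.type, enabled=(device.type == "cuda")):
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)      # skips the step if gradients overflowed
        scaler.update()
        scheduler.step()            # OneCycle advances once per batch
```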
@@ -37,7 +37,7 @@ Supports local development, SageMaker training, flexible dataset prep, Weights &
 ## Quick Start
 ### 1. Clone & Install
 ```bash
-git clone https://github.com/<your-username>/food101-classifier.git
+git clone https://github.com/codinglabsong/food101-end2end-classifier-sagemaker-gradio.git
 cd food101-classifier
 pip install -r requirements.txt
 ```
@@ -90,8 +90,8 @@ Edit `.env` using `.env.example` as a guide for AWS and wandb keys.
 > The preprocessing pipeline (image resizing, cropping, normalization) **must be identical** between training and inference (including Gradio app or deployment).
 >
 > - All transforms should use parameters from `config/prod.yaml` (or your config file).
-> - The value of `img_size` used for training and inference must always be ≤ 256, since images are first resized so their short edge is 256 before center cropping.
-> - **Do not set `img_size` greater than 256.** This would result in errors or ineffective cropping during inference.
+> - The value of `img_size` used for training and inference must always be ≤ 512, since images are first resized so their short edge is 512 before center cropping.
+> - **Do not set `img_size` greater than 512.** This would result in errors or ineffective cropping during inference.

 **Best practice:**
 Update only your config file (not hardcoded values) when changing image size or normalization, and always reload configs in both training and inference code.
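This guidance is the crux of the commit's preprocessing fix. A minimal sketch of the config-driven pattern it describes, assuming the `estimator.hyperparameters` layout of `config/prod.yaml` shown below; note that the code in this same commit resizes the short edge to 256 (see the gradio_app.py and train.py hunks), so the crop size must stay at or below that resize edge:

```python
import yaml
from torchvision import transforms

def build_eval_transform(cfg_path: str = "config/prod.yaml") -> transforms.Compose:
    """Build the inference-side transform from the shared config file."""
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    hp = cfg["estimator"]["hyperparameters"]
    img_size = hp.get("img-size", hp.get("img_size"))  # key renamed to img-size in this commit
    assert img_size <= 256, "img_size must not exceed the Resize short edge"
    return transforms.Compose([
        transforms.Resize(256),           # short edge -> 256, same as train.py's test_tfms
        transforms.CenterCrop(img_size),  # must match the training-time crop
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
```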
@@ -112,8 +112,8 @@ This project includes an interactive Gradio app for making predictions with the

 ## Requirements
 - See `requirements.txt`
-- Python 3.8
-- PyTorch >= 2.2
+- Python >= 3.9
+- PyTorch >= 2.6

 ## Contributing
 Open to issues and pull requests!
@@ -128,4 +128,4 @@ This project is licensed under the MIT License.

 ## Tips:
 - .env.example helps keep secrets out of git.
-- .gitignore: Don't track datasets, outputs, or .env.
+- .gitignore: Don't track datasets, outputs, wandb, or .env.

config/prod.yaml

Lines changed: 12 additions & 12 deletions
@@ -1,19 +1,19 @@
 estimator:
   hyperparameters:
     seed: 42
-    batch_size: 128
-    num_epochs_phase1: 10
-    num_epochs_phase2: 8
-    lr_head: 4e-3
-    lr_backbone: 4e-4
+    batch-size: 128
+    num-epochs-phase1: 8
+    num-epochs-phase2: 10
+    lr-head: 4e-3
+    lr-backbone: 5e-4
     patience: 3
-    num_workers: 3
-    img_size: 224
+    num-workers: 2
+    img-size: 224
   instance_count: 1
-  instance_type: "ml.m5.xlarge"
-  framework_version: "2.2.0"
-  py_version: "py310"
-  base_job_name: "mnist-cnn"
+  instance_type: "ml.g4dn.xlarge"
+  framework_version: "2.6.0"
+  py_version: "py312"
+  base_job_name: "food101-classifier"
   use_spot_instances: true
-  max_run: 7200 # seconds
+  max_run: 10800 # seconds
   max_wait: 14400 # seconds, needed when using spot instances or otherwise wait indefinitely
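For context, these estimator fields map onto the SageMaker Python SDK roughly as follows; a hedged sketch, not the repo's actual launcher, with placeholder entry point, source dir, IAM role, and S3 URI:

```python
import yaml
from sagemaker.pytorch import PyTorch

with open("config/prod.yaml") as f:
    cfg = yaml.safe_load(f)["estimator"]

estimator = PyTorch(
    entry_point="train.py",                      # assumed entry point
    source_dir="src",                            # assumed layout
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder IAM role
    instance_count=cfg["instance_count"],
    instance_type=cfg["instance_type"],          # ml.g4dn.xlarge carries a T4 GPU
    framework_version=cfg["framework_version"],  # matches PyTorch >= 2.6 in the README
    py_version=cfg["py_version"],
    base_job_name=cfg["base_job_name"],
    hyperparameters=cfg["hyperparameters"],      # passed to train.py as --key value flags
    use_spot_instances=cfg["use_spot_instances"],
    max_run=cfg["max_run"],
    max_wait=cfg["max_wait"],                    # required when using spot instances
)
estimator.fit({"training": "s3://your-bucket/food101"})  # placeholder S3 URI
```

The switch from ml.m5.xlarge (CPU-only) to ml.g4dn.xlarge (NVIDIA T4, compute capability 7.5) is also what makes the capability-gated `torch.compile` path in train.py reachable.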

gradio_app.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def build_model(num_classes):

 # 4. Preprocessing: same as test transforms in train.py
 preprocess = transforms.Compose([
-    transforms.Resize(512),
+    transforms.Resize(256),
     transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img_size"]),
     transforms.ToTensor(),
     transforms.Normalize([0.485,0.456,0.406],
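This one-line change matters because `transforms.Resize` with a single int scales the short edge to that value while keeping aspect ratio. A standalone check of the old versus new pipeline (the input image size here is illustrative):

```python
from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (640, 480))  # (width, height)

old = transforms.Resize(512)(img)   # short edge 480 -> 512: an upscale
new = transforms.Resize(256)(img)   # short edge 480 -> 256: a downscale
print(old.size, new.size)           # (682, 512) (341, 256)

# CenterCrop(224) keeps ~44% of a 512 short edge but ~88% of a 256 one,
# so the 256 resize preserves far more of the scene in the final crop
# and now matches the "short edge=256" comment in train.py's test_tfms.
print(transforms.CenterCrop(224)(new).size)  # (224, 224)
```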

src/train.py

Lines changed: 22 additions & 26 deletions
@@ -15,14 +15,14 @@ def parse_args():

     # hyperparameters sent by the client (same flag names as estimator hyperparameters)
     p.add_argument("--seed", type=int, default=42)
-    p.add_argument("--batch-size", type=int, default=512)
-    p.add_argument("--num-epochs-phase1", type=int, default=2)
-    p.add_argument("--num-epochs-phase2", type=int, default=2)
-    p.add_argument("--lr-head", type=float, default=16e-3)
-    p.add_argument("--lr-backbone", type=float, default=16e-4)
+    p.add_argument("--batch-size", type=int, default=128)
+    p.add_argument("--num-epochs-phase1", type=int, default=10)
+    p.add_argument("--num-epochs-phase2", type=int, default=8)
+    p.add_argument("--lr-head", type=float, default=4e-3)
+    p.add_argument("--lr-backbone", type=float, default=4e-4)
     p.add_argument("--patience", type=int, default=3)
     p.add_argument("--num-workers", type=int, default=2)
-    p.add_argument("--img-size", type=int, default=384)
+    p.add_argument("--img-size", type=int, default=224)

     # other variables
     p.add_argument("--wandb-project", type=str, default="food101-classifier")
@@ -119,7 +119,7 @@ def main():
                              [0.229,0.224,0.225])
     ])
     test_tfms = transforms.Compose([
-        transforms.Resize(512), # shrink so short edge=256
+        transforms.Resize(256), # shrink so short edge=256
         transforms.CenterCrop(cfg.img_size), # take middle window
         transforms.ToTensor(),
         transforms.Normalize([0.485,0.456,0.406],
@@ -151,7 +151,7 @@
     test_dl = DataLoader(test_ds, batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=True)

     print(f"Data ready. len(train)={len(train_ds)}, len(val)={len(val_ds)}, len(test)={len(test_ds)}")
-
+
     # ---------- Model Training Preparation ----------
     # create the model
     def build_model(num_classes: int) -> nn.Module:
@@ -175,14 +175,6 @@ def build_model(num_classes: int) -> nn.Module:
     print(f"number of class labels: {len(class_names)}")
     model = build_model(len(class_names))

-    # try compile if supported:
-    if DEVICE.type == "cuda" and torch.cuda.is_available():
-        cap = torch.cuda.get_device_properties(DEVICE).major
-        if cap >= 7:
-            model = torch.compile(model)
-        else:
-            print(f"GPU CC {cap}.x detected - skipping torch.compile()")
-
     criterion = nn.CrossEntropyLoss() # standard multi-class loss

     # one epoch function
@@ -192,7 +184,6 @@ def build_model(num_classes: int) -> nn.Module:
     else:
         scaler = None

-    step_counters = {'train': 0, 'val': 0}
     def epoch_loop (phase: str,
                     model: nn.Module,
                     loader: DataLoader,
@@ -240,12 +231,11 @@ def epoch_loop (phase: str,
                 run_correct += (outputs.argmax(1) == y).sum().item()
                 imgs_processed += batch_size # add to throughput counter

-                # wandb: batch logging (train & val only)
-                if phase in ["train", "val"]:
+                # wandb: batch logging (train only)
+                if is_train:
                     wandb.log({
-                        f"{phase}/batch_loss": loss.item(),
-                    }, step=step_counters[phase])
-                    step_counters[phase] += 1
+                        "train/batch_loss": loss.item(),
+                    })

             if torch.cuda.is_available():
                 torch.cuda.synchronize() # CPU waits until GPU finishes. More accurate dt.
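A plausible reading of this change and the `step_counters` removal above: when `wandb.log` is called without `step`, wandb assigns a single auto-incrementing global step, whereas hand-maintained per-phase counters can move the step backwards once train and val logging interleave, and wandb warns on or drops non-monotonic steps. A minimal sketch of the simplified pattern, using offline mode so it runs without an account:

```python
import wandb

run = wandb.init(project="food101-classifier", mode="offline")
for batch_idx in range(3):
    loss = 1.0 / (batch_idx + 1)           # stand-in for a real training loss
    wandb.log({"train/batch_loss": loss})  # global step auto-increments: 0, 1, 2
wandb.log({"train/epoch_acc": 0.71})       # lands on the next global step
run.finish()
```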
@@ -278,7 +268,7 @@
                 f"{phase}/loss_scale": loss_scale,
                 f"{phase}/peak_mem_MB": peak_mem_MB,
             })
-        wandb.log(metrics, step=step_counters[phase] - 1) # ensures logging at the same step as the last batch of that epoch
+        wandb.log(metrics) # wandb assigns its auto-incremented global step
         return epoch_loss, epoch_acc

     # checkpoint helper
@@ -302,7 +292,7 @@ def save_ckpt(state: Dict, filename: str, model_dir: str) -> None:
         optimizer,
         max_lr=cfg.lr_head,
         total_steps=total_steps,
-        pct_start=0.2, # 20% of total steps for LR warm-up
+        pct_start=0.35, # 35% of total steps for LR warm-up
         anneal_strategy="cos", # cosine annealing down
     )
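In OneCycleLR, `pct_start` is the fraction of `total_steps` spent ramping up to `max_lr` before annealing down, so this raises phase 1's warm-up from 20% to 35% of the run (the next hunk shortens phase 2's to 15%). A standalone check of where the peak lands, with a toy model and step count:

```python
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import OneCycleLR

model = nn.Linear(4, 2)  # toy model
opt = optim.Adam(model.parameters(), lr=4e-3)
sched = OneCycleLR(opt, max_lr=4e-3, total_steps=100,
                   pct_start=0.35, anneal_strategy="cos")

lrs = []
for _ in range(100):
    opt.step()    # optimizer must step before the scheduler
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(lrs.index(max(lrs)) + 1)  # peak reached near step 35 (= 0.35 * 100)
```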

@@ -335,17 +325,23 @@
     print("\nPhase 2: fine-tune")

     # unfreeze backbone
+    print("\nUnfreezing backbone...")
     for p in model.parameters():
         p.requires_grad = True
+
+    if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 7:
+        torch.cuda.empty_cache() # free the memory (helpful, but optional)
+        model = torch.compile(model)
+        print(f"GPU CC {torch.cuda.get_device_properties(DEVICE).major}.x detected - compiled model")

-    optimizer = optim.Adam(model.parameters(), lr=cfg.lr_backbone)
+    optimizer = optim.Adam(model.parameters(), lr=cfg.lr_backbone)
     total_steps = cfg.num_epochs_phase2 * n_steps_per_epoch

     scheduler = OneCycleLR(
         optimizer,
         max_lr=cfg.lr_backbone,
         total_steps=total_steps,
-        pct_start=0.2,
+        pct_start=0.15,
         anneal_strategy="cos",
     )
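This block restores the capability-gated `torch.compile` that the earlier hunk removed, but runs it in phase 2 only, after every parameter is unfrozen. A likely rationale (my reading, not stated in the commit message) is that flipping `requires_grad` on an already-compiled model invalidates its compiled graphs and forces recompilation, so compiling once, after the final parameter state is set, avoids that cost. Factored as a reusable helper, the guard might look like:

```python
import torch

def maybe_compile(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """Compile only on CUDA devices new enough to benefit (compute capability >= 7.0)."""
    if device.type == "cuda" and torch.cuda.is_available():
        cap = torch.cuda.get_device_properties(device).major
        if cap >= 7:
            torch.cuda.empty_cache()  # optional: release cached blocks before compiling
            return torch.compile(model)
        print(f"GPU CC {cap}.x detected - skipping torch.compile()")
    return model
```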
