Skip to content

Commit f986181

Browse files
committed
tuning the toy model
1 parent 0e3464d commit f986181

File tree

10 files changed

+33901
-96
lines changed

10 files changed

+33901
-96
lines changed

base/base_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def __init__(self):
102102
###########################
103103
cfg.MODEL = CN()
104104
cfg.MODEL.NAME = ""
105+
cfg.MODEL.MLP = CN()
106+
cfg.MODEL.MLP.INPUT_DIM = 784
107+
cfg.MODEL.MLP.HIDDEN_DIM = [128, 64]
108+
cfg.MODEL.MLP.OUTPUT_DIM = 10
105109
# Path to model weights (for initialization)
106110
cfg.MODEL.INIT_WEIGHTS = ""
107111
# Definition of embedding layers

base/base_dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,12 @@ def split_train_val(self, val_percent, seed=42):
158158
val_indices = indices[:val_size]
159159
train_indices = indices[val_size:]
160160

161+
# Save original train data before reassigning
162+
original_train = self._train
163+
161164
# Create new splits
162-
self._train = [self._train[i] for i in train_indices]
163-
self._val = [self._train[i] for i in val_indices]
165+
self._train = [original_train[i] for i in train_indices]
166+
self._val = [original_train[i] for i in val_indices]
164167
logger.info(f"Split complete: {len(self._train)} train, {len(self._val)} val")
165168

166169
def download_data(self, url, dst, from_gdrive=True):

base/base_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def set_model_mode(self, mode="train", names=None):
242242
> close_writer
243243
> write_scalar
244244
"""
245-
def init_writer(self, log_dir, extra_config=None):
245+
def init_writer(self, extra_config=None):
246246
# Only initialize writer on main process
247247
if not self.is_main_process():
248248
return

config/toy_trainer_config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ DATALOADER:
2929

3030
MODEL:
3131
NAME: ToyModel
32-
INIT_WEIGHTS:
32+
MLP:
33+
INPUT_DIM: 784
34+
HIDDEN_DIM: [128, 64]
35+
OUTPUT_DIM: 10
3336

3437
OPTIM:
3538
NAME: adam

dataset/HDTF_TFHP.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ class HDTF_TFHP(DatasetBase):
2323

2424
def __init__(self, cfg):
2525
# data config and path
26-
root = os.path.abspath(os.path.expanduser(cfg.ROOT))
27-
self.dataset_dir = os.path.join(root, cfg.NAME)
26+
self.dataset_dir = os.path.join(cfg.ROOT, cfg.NAME)
2827
lmdb_path = self.dataset_dir
2928
split_path = [os.path.join(self.dataset_dir, cfg.HDTF_TFHP.TRAIN),
3029
os.path.join(self.dataset_dir, cfg.HDTF_TFHP.VAL),

dataset/MINIST.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@ class MNIST(DatasetBase):
1717
"""
1818

1919
def __init__(self, cfg):
20-
# Data config and path
21-
print(cfg.ROOT)
22-
root = os.path.abspath(os.path.expanduser(cfg.ROOT))
23-
print(root)
24-
self.dataset_dir = os.path.join(root, cfg.NAME)
25-
print(self.dataset_dir)
26-
os.makedirs(self.dataset_dir, exist_ok=True) # Create directory if not exists
27-
2820
# Define transformations
2921
self.transform = transforms.Compose([
3022
transforms.ToTensor(),
@@ -34,13 +26,13 @@ def __init__(self, cfg):
3426
# Load MNIST train and test datasets (will download if not exists)
3527
try:
3628
train_dataset = datasets.MNIST(
37-
root=self.dataset_dir,
29+
root=cfg.ROOT,
3830
train=True,
3931
download=True,
4032
transform=self.transform
4133
)
4234
test_dataset = datasets.MNIST(
43-
root=self.dataset_dir,
35+
root=cfg.ROOT,
4436
train=False,
4537
download=True,
4638
transform=self.transform

models/lib/network/mlp.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import torch
23
import torch.nn as nn
34

45
# from ..head.build import HEAD_REGISTRY
@@ -10,16 +11,23 @@ def __init__(
1011
self,
1112
in_features=2048,
1213
hidden_layers=[],
14+
out_features=None,
1315
activation="relu",
1416
bn=True,
1517
dropout=0.0,
18+
1619
):
1720
super().__init__()
1821
if isinstance(hidden_layers, int):
1922
hidden_layers = [hidden_layers]
2023

2124
assert len(hidden_layers) > 0
22-
self.out_features = hidden_layers[-1]
25+
26+
# If out_features is not specified, use the last hidden layer dimension
27+
if out_features is None:
28+
out_features = hidden_layers[-1]
29+
self.out_features = out_features
30+
self.in_features = in_features
2331

2432
mlp = []
2533

@@ -33,15 +41,23 @@ def __init__(
3341
for hidden_dim in hidden_layers:
3442
mlp += [nn.Linear(in_features, hidden_dim)]
3543
if bn:
36-
mlp += [nn.BatchNorm1d(hidden_dim)]
44+
mlp += [nn.LayerNorm(hidden_dim)]
3745
mlp += [act_fn()]
3846
if dropout > 0:
3947
mlp += [nn.Dropout(dropout)]
4048
in_features = hidden_dim
4149

50+
# Add final projection layer if output dimension differs from last hidden layer
51+
if out_features != hidden_layers[-1]:
52+
mlp += [nn.Linear(hidden_layers[-1], out_features)]
53+
4254
self.mlp = nn.Sequential(*mlp)
4355

4456
def forward(self, x):
57+
# Flatten input if it has more than 2 dimensions
58+
if x.dim() > 2:
59+
x = x.view(x.size(0), -1)
60+
4561
return self.mlp(x)
4662

4763

models/toymodel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
class ToyModel(nn.Module):
99
"""Simple MLP for MNIST digit recognition"""
10-
def __init__(self):
10+
def __init__(self, cfg):
1111
super(ToyModel, self).__init__()
12-
self.net = MLP(in_features=1*28*28,
13-
hidden_layers=[20, 10],
12+
self.net = MLP(in_features=cfg.INPUT_DIM,
13+
hidden_layers=cfg.HIDDEN_DIM,
14+
out_features=cfg.OUTPUT_DIM,
1415
activation='relu',
1516
bn=True, dropout=0.1)
1617

0 commit comments

Comments
 (0)