Skip to content

Commit 738df2d

Browse files
update input normalization values and hyperparams in cifar10 runscripts.
Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent 8f155bd commit 738df2d

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

bayesian_torch/examples/main_bayesian_cifar.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@
5050
help='manual epoch number (useful on restarts)')
5151
parser.add_argument('-b',
5252
'--batch-size',
53-
default=512,
53+
default=128,
5454
type=int,
5555
metavar='N',
56-
help='mini-batch size (default: 512)')
56+
help='mini-batch size (default: 128)')
5757
parser.add_argument('--lr',
5858
'--learning-rate',
59-
default=0.1,
59+
default=0.001,
6060
type=float,
6161
metavar='LR',
6262
help='initial learning rate')
@@ -67,7 +67,7 @@
6767
help='momentum')
6868
parser.add_argument('--weight-decay',
6969
'--wd',
70-
default=1e-4,
70+
default=5e-4,
7171
type=float,
7272
metavar='W',
7373
help='weight decay (default: 5e-4)')
@@ -223,8 +223,8 @@ def main():
223223
os.makedirs(logger_dir)
224224
tb_writer = SummaryWriter(logger_dir)
225225

226-
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
227-
std=[0.229, 0.224, 0.225])
226+
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
227+
std=[0.2023, 0.1994, 0.2010])
228228

229229
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
230230
root='./data',

bayesian_torch/examples/main_bayesian_flipout_cifar.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@
5151
help='manual epoch number (useful on restarts)')
5252
parser.add_argument('-b',
5353
'--batch-size',
54-
default=512,
54+
default=128,
5555
type=int,
5656
metavar='N',
57-
help='mini-batch size (default: 512)')
57+
help='mini-batch size (default: 128)')
5858
parser.add_argument('--lr',
5959
'--learning-rate',
60-
default=0.1,
60+
default=0.001,
6161
type=float,
6262
metavar='LR',
6363
help='initial learning rate')
@@ -68,7 +68,7 @@
6868
help='momentum')
6969
parser.add_argument('--weight-decay',
7070
'--wd',
71-
default=1e-4,
71+
default=5e-4,
7272
type=float,
7373
metavar='W',
7474
help='weight decay (default: 5e-4)')
@@ -170,8 +170,8 @@ def main():
170170
os.makedirs(logger_dir)
171171
tb_writer = SummaryWriter(logger_dir)
172172

173-
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
174-
std=[0.229, 0.224, 0.225])
173+
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
174+
std=[0.2023, 0.1994, 0.2010])
175175

176176
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
177177
root='./data',

bayesian_torch/examples/main_deterministic_cifar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
help='momentum')
6666
parser.add_argument('--weight-decay',
6767
'--wd',
68-
default=1e-4,
68+
default=5e-4,
6969
type=float,
7070
metavar='W',
7171
help='weight decay (default: 5e-4)')
@@ -157,8 +157,8 @@ def main():
157157
os.makedirs(logger_dir)
158158
tb_writer = SummaryWriter(logger_dir)
159159

160-
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
161-
std=[0.229, 0.224, 0.225])
160+
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
161+
std=[0.2023, 0.1994, 0.2010])
162162

163163
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
164164
root='./data',

0 commit comments

Comments
 (0)