Skip to content

Commit a5d44c9

Browse files
committed
Fix meters
The meters in DARTS were updated when they were moved to common, but the logging in P3B5 was not updated.
1 parent 2b10631 commit a5d44c9

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

Pilot3/P3B5/p3b5_baseline_pytorch.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run(params):
9797
print(f'Genotype: {genotype}')
9898

9999
# training
100-
train_acc, train_loss = train(
100+
train(
101101
trainloader,
102102
validloader,
103103
model,
@@ -112,16 +112,20 @@ def run(params):
112112
)
113113

114114
# validation
115-
valid_acc, valid_loss = infer(validloader, model, criterion, args, tasks, device, valid_meter)
115+
valid_loss = infer(
116+
validloader,
117+
model,
118+
criterion,
119+
args,
120+
tasks,
121+
device,
122+
valid_meter
123+
)
116124

117125
if valid_loss < min_loss:
118126
genotype_store.save_genotype(genotype)
119127
min_loss = valid_loss
120128

121-
print(f'\nEpoch {epoch} stats:')
122-
# darts.log_accuracy(train_acc, 'train')
123-
# darts.log_accuracy(valid_acc, 'valid')
124-
125129

126130
def main():
127131
params = initialize_parameters()

Pilot3/P3B5/p3b5_darts.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,10 @@ def train(trainloader, validloader, model, architecture, criterion, optimizer, l
6666
meter.update_batch_accuracy(prec1, batch_size)
6767

6868
if step % args.log_interval == 0:
69-
print(f'Step: {step} loss: {losses.avg:.4}')
70-
#darts.log_accuracy(top1)
69+
print(f'Step: {step} loss: {meter.loss_meter.avg:.4}')
7170

7271
meter.update_epoch()
7372
meter.save(args.savepath)
74-
return top1, losses.avg
7573

7674

7775
def infer(validloader, model, criterion, args, tasks, device, meter):
@@ -94,12 +92,12 @@ def infer(validloader, model, criterion, args, tasks, device, meter):
9492
meter.update_batch_accuracy(prec1, batch_size)
9593

9694
if step % args.log_interval == 0:
97-
print(f'>> Validation: {step} loss: {losses.avg:.4}')
98-
#darts.log_accuracy(top1, 'valid')
95+
print(f'>> Validation: {step} loss: {meter.loss_meter.avg:.4}')
9996

10097
meter.update_epoch()
10198
meter.save(args.savepath)
102-
return top1, losses.avg
99+
100+
return meter.loss_meter.avg
103101

104102

105103
if __name__=='__main__':

0 commit comments

Comments
 (0)