4
4
from torch .utils .data import DataLoader
5
5
from torchvision import datasets , transforms
6
6
7
+ import logging
8
+
7
9
import example_setup as bmk
8
10
import darts
9
11
import candle
13
15
)
14
16
15
17
18
+ logging .basicConfig (level = logging .INFO )
19
+ logger = logging .getLogger ("darts_advanced" )
20
+
21
+
16
22
def initialize_parameters ():
17
23
""" Initialize the parameters for the Advanced example """
18
24
@@ -89,10 +95,10 @@ def run(params):
89
95
90
96
scheduler .step ()
91
97
lr = scheduler .get_lr ()[0 ]
92
- print (f'\n Epoch: { epoch } lr: { lr } ' )
98
+ logger . info (f'\n Epoch: { epoch } lr: { lr } ' )
93
99
94
100
genotype = model .genotype ()
95
- print (f'Genotype: { genotype } \n ' )
101
+ logger . info (f'Genotype: { genotype } \n ' )
96
102
97
103
train (
98
104
trainloader ,
@@ -160,7 +166,7 @@ def train(trainloader,
160
166
meter .update_batch_accuracy (prec1 , batch_size )
161
167
162
168
if step % args .log_interval == 0 :
163
- print (f'Step: { step } loss: { meter .loss_meter .avg :.4} ' )
169
+ logger . info (f'Step: { step } loss: { meter .loss_meter .avg :.4} ' )
164
170
165
171
meter .update_epoch ()
166
172
meter .save (args .savepath )
@@ -185,7 +191,7 @@ def validate(validloader, model, criterion, args, tasks, meter, device):
185
191
meter .update_batch_accuracy (prec1 , batch_size )
186
192
187
193
if step % args .log_interval == 0 :
188
- print (f'>> Validation: { step } loss: { meter .loss_meter .avg :.4} ' )
194
+ logger . info (f'>> Validation: { step } loss: { meter .loss_meter .avg :.4} ' )
189
195
190
196
meter .update_epoch ()
191
197
meter .save (args .savepath )
0 commit comments