4
4
import logging
5
5
6
6
import numpy as np
7
- import pandas as pd
8
7
9
8
import mxnet as mx
10
9
from mxnet .io import DataBatch , DataIter
11
10
12
- # For non-interactive plotting
13
- import matplotlib as mpl
14
- mpl .use ('Agg' )
15
- import matplotlib .pyplot as plt
11
+ # # For non-interactive plotting
12
+ # import matplotlib as mpl
13
+ # mpl.use('Agg')
14
+ # import matplotlib.pyplot as plt
16
15
17
16
import p1b3
18
17
@@ -232,8 +231,6 @@ def main():
232
231
args = parser .parse_args ()
233
232
print ('Args:' , args )
234
233
235
- # it = RegressionDataIter()
236
-
237
234
loggingLevel = logging .DEBUG if args .verbose else logging .INFO
238
235
logging .basicConfig (level = loggingLevel , format = '' )
239
236
@@ -264,6 +261,7 @@ def main():
264
261
net = mx .sym .Activation (data = net , act_type = args .activation )
265
262
if args .pool :
266
263
net = mx .sym .Pooling (data = net , pool_type = "max" , kernel = (args .pool , 1 ), stride = (1 , 1 ))
264
+ net = mx .sym .Flatten (data = net )
267
265
268
266
for layer in args .dense :
269
267
if layer :
@@ -288,10 +286,12 @@ def main():
288
286
label_names = ('growth' ,),
289
287
context = devices )
290
288
289
+ initializer = mx .init .Xavier (factor_type = "in" , magnitude = 2.34 )
291
290
mod .fit (train_iter , eval_data = val_iter ,
292
291
eval_metric = args .loss ,
293
292
optimizer = args .optimizer ,
294
293
num_epoch = args .epochs ,
294
+ initializer = initializer ,
295
295
batch_end_callback = mx .callback .Speedometer (args .batch_size , 20 ))
296
296
297
297
0 commit comments