Skip to content

Commit c35326d

Browse files
committed
minor update
1 parent 0e62677 commit c35326d

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

P1B3/p1b3_baseline_mxnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import logging
55

66
import numpy as np
7-
import pandas as pd
87

98
import mxnet as mx
109
from mxnet.io import DataBatch, DataIter
1110

12-
# For non-interactive plotting
13-
import matplotlib as mpl
14-
mpl.use('Agg')
15-
import matplotlib.pyplot as plt
11+
# # For non-interactive plotting
12+
# import matplotlib as mpl
13+
# mpl.use('Agg')
14+
# import matplotlib.pyplot as plt
1615

1716
import p1b3
1817

@@ -232,8 +231,6 @@ def main():
232231
args = parser.parse_args()
233232
print('Args:', args)
234233

235-
# it = RegressionDataIter()
236-
237234
loggingLevel = logging.DEBUG if args.verbose else logging.INFO
238235
logging.basicConfig(level=loggingLevel, format='')
239236

@@ -264,6 +261,7 @@ def main():
264261
net = mx.sym.Activation(data=net, act_type=args.activation)
265262
if args.pool:
266263
net = mx.sym.Pooling(data=net, pool_type="max", kernel=(args.pool, 1), stride=(1, 1))
264+
net = mx.sym.Flatten(data=net)
267265

268266
for layer in args.dense:
269267
if layer:
@@ -288,10 +286,12 @@ def main():
288286
label_names=('growth',),
289287
context=devices)
290288

289+
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
291290
mod.fit(train_iter, eval_data=val_iter,
292291
eval_metric=args.loss,
293292
optimizer=args.optimizer,
294293
num_epoch=args.epochs,
294+
initializer=initializer,
295295
batch_end_callback = mx.callback.Speedometer(args.batch_size, 20))
296296

297297

0 commit comments

Comments
 (0)