Skip to content

Commit 62b7249

Browse files
committed
Support for new Tensorflow
1 parent 38e6d30 commit 62b7249

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

gestureCNN.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
I didnt spend much time on this behavior, but if someone has answer to this then please do comment and let me know.
2626
ValueError: Negative dimension size caused by subtracting 3 from 1 for 'conv2d_1/convolution' (op: 'Conv2D') with input shapes: [?,1,200,200], [3,3,200,32].
2727
'''
28-
K.set_image_dim_ordering('th')
28+
#K.set_image_dim_ordering('th')
29+
K.set_image_data_format('channels_first')
2930

3031

3132
import numpy as np
@@ -217,8 +218,7 @@ def guessGesture(model, img):
217218
#prob_array = model.predict_proba(rimage)
218219

219220
prob_array = get_output([rimage, 0])[0]
220-
221-
#print prob_array
221+
#print('prob_array: ',prob_array)
222222

223223
d = {}
224224
i = 0
@@ -327,27 +327,34 @@ def trainModel(model):
327327
hist = model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch,
328328
verbose=1, validation_split=0.2)
329329

330-
visualizeHis(hist)
331-
332330
ans = input("Do you want to save the trained weights - y/n ?")
333331
if ans == 'y':
334332
filename = input("Enter file name - ")
335333
fname = path + str(filename) + ".hdf5"
336334
model.save_weights(fname,overwrite=True)
337335
else:
338336
model.save_weights("newWeight.hdf5",overwrite=True)
337+
338+
visualizeHis(hist)
339339

340340
# Save model as well
341341
# model.save("newModel.hdf5")
342342
#%%
343343

344344
def visualizeHis(hist):
345345
# visualizing losses and accuracy
346-
346+
keylist = hist.history.keys()
347+
#print(hist.history.keys())
347348
train_loss=hist.history['loss']
348349
val_loss=hist.history['val_loss']
349-
train_acc=hist.history['acc']
350-
val_acc=hist.history['val_acc']
350+
351+
#Tensorflow new updates seem to have different key name
352+
if 'acc' in keylist:
353+
train_acc=hist.history['acc']
354+
val_acc=hist.history['val_acc']
355+
else:
356+
train_acc=hist.history['accuracy']
357+
val_acc=hist.history['val_accuracy']
351358
xc=range(nb_epoch)
352359

353360
plt.figure(1,figsize=(7,5))

0 commit comments

Comments
 (0)