Skip to content

Commit aeab361

Browse files
committed
cast class prediction to float for theano compatibility
1 parent c486dd8 commit aeab361

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

P1B3/p1b3_baseline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,13 @@ def evaluate_model(model, generator, samples, metric, category_cutoffs=[0.]):
199199
y_pred = np.concatenate((y_pred, y_batch_pred)) if y_pred is not None else y_batch_pred
200200
count += len(y_batch)
201201

202-
loss = evaluate_keras_metric(y_true, y_pred, metric)
202+
loss = evaluate_keras_metric(y_true.astype(np.float32), y_pred.astype(np.float32), metric)
203203

204204
y_true_class = np.digitize(y_true, category_cutoffs)
205205
y_pred_class = np.digitize(y_pred, category_cutoffs)
206206

207-
acc = evaluate_keras_metric(y_true_class, y_pred_class, 'binary_accuracy') # works for multiclass labels as well
207+
# theano does not like integer input
208+
acc = evaluate_keras_metric(y_true_class.astype(np.float32), y_pred_class.astype(np.float32), 'binary_accuracy') # works for multiclass labels as well
208209

209210
return loss, acc, y_true, y_pred, y_true_class, y_pred_class
210211

0 commit comments

Comments
 (0)