-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsubmit.py
More file actions
70 lines (46 loc) · 1.72 KB
/
submit.py
File metadata and controls
70 lines (46 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import ast
import datetime as dt
import os
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from keras.models import Model, load_model
from tensorflow import keras
from common import *
# Uncomment to hide all CUDA devices and run inference on the CPU.
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
start = dt.datetime.now()
model_path = './model/3/weights-001-0.942.hdf5'
print('Loading model at', model_path)
# Load the previously trained model. The custom metric has to be passed via
# custom_objects so Keras can resolve it while deserializing the saved model
# (top_3_accuracy presumably comes from the `common` star import).
trained_model = load_model(model_path, custom_objects={
    'top_3_accuracy': top_3_accuracy})
# `model` is the predictor the inference loop calls (`model.predict`).
# Binding it to the plain network by default keeps the script runnable;
# previously the only assignment was commented out, leaving `model` undefined.
# Swap in the wrapper line below for horizontal-flip test-time augmentation.
model = trained_model
# model = TTA_ModelWrapper(trained_model)
print('Loaded model. Predicting')
test = pd.read_csv(os.path.join(INPUT_DIR, 'test_simplified.csv'))
# Number of LOAD_SIZE-row chunks needed to cover the test set (ceiling
# division). The former `int(n / LOAD_SIZE) + 1` scheduled an extra, empty
# chunk whenever n was an exact multiple of LOAD_SIZE.
# NOTE(review): assumes df_to_image_array_xd(df, size, step) rasterizes rows
# [step * LOAD_SIZE, (step + 1) * LOAD_SIZE) of df — confirm in `common`.
max_load_step = (test.shape[0] + LOAD_SIZE - 1) // LOAD_SIZE
# max_load_step = 4  # debug override: limit inference to the first chunks
cats = list_all_categories()
# Class index -> category name, with spaces replaced by underscores.
id2cat = {k: cat.replace(' ', '_') for k, cat in enumerate(cats)}
# Gather per-chunk predictions and concatenate once at the end; concatenating
# inside the loop re-copies the growing array on every step (quadratic).
chunk_predictions = []
for load_step in range(max_load_step):
    # `size` presumably is the image side length from `common` — confirm.
    x_test = df_to_image_array_xd(test, size, load_step)
    new_predictions = model.predict(x_test, batch_size=128, verbose=1)
    chunk_predictions.append(new_predictions)
# Stays None when no chunk was predicted, as before.
test_predictions = np.concatenate(chunk_predictions) if chunk_predictions else None
# Top-3 predicted class ids per test row, then mapped to category names.
# Columns 'a', 'b', 'c' are read below; presumably preds2catids produces
# them — confirm in `common`.
top3 = preds2catids(test_predictions)
top3cats = top3.replace(id2cat)
# Last training shard; the name `valid_df` suggests it serves as the
# validation split.
valid_df = pd.read_csv(os.path.join(
    DP_DIR, 'train_k{}.csv.gz'.format(NCSVS - 1)))
# NOTE(review): this scores *test* predictions against *validation* labels.
# The two sets of rows are unrelated, so the result is not a real MAP@3,
# yet it is printed and embedded in the submission file name. To fix,
# predict on valid_df and pass those top-3 ids to mapk instead.
map3 = mapk(valid_df[['y']].values, top3.values)
print('Map3: {:.3f}'.format(map3))
# Submission format: the three predicted labels, space-separated.
test['word'] = top3cats['a'] + ' ' + top3cats['b'] + ' ' + top3cats['c']
submission = test[['key_id', 'word']]
submission.to_csv('gs_mn_submission_{}.csv'.format(
    int(map3 * 10**4)), index=False)
end = dt.datetime.now()
# total_seconds() covers the whole duration; timedelta.seconds alone drops
# the days component and under-reports runs longer than 24 hours.
print('Latest run {}.\nTotal time {}s'.format(
    end, int((end - start).total_seconds())))