Skip to content

Commit dca0646

Browse files
authored
Merge pull request #14 to fix inference issue
Small hacks to fix #13 due to lack of tags in the inference stage
2 parents 188721d + 9f59121 commit dca0646

File tree

4 files changed

+11
-4
lines changed

4 files changed

+11
-4
lines changed

loader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,13 @@ def f(x):
148148
chars = [[char_to_id[c] for c in w if c in char_to_id]
149149
for w in str_words]
150150
caps = [cap_feature(w) for w in str_words]
151-
tags = [tag_to_id[w[-1]] for w in s]
151+
152+
# Hack: This is for an inference stage where tag_to_id is not necessary
153+
if tag_to_id:
154+
tags = [tag_to_id[w[-1]] for w in s]
155+
else:
156+
tags = tag_to_id
157+
152158
data.append({
153159
'str_words': str_words,
154160
'words': words,

run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
file.write('\n'.join(string.split()) + '\n')
6969
file.close()
7070
test_sentences = load_sentences(test_file, lower, zeros)
71-
data = prepare_dataset(test_sentences, word_to_id, char_to_id, lower, True)
71+
data = prepare_dataset(test_sentences, word_to_id, char_to_id, {}, lower, True)
7272

7373
for citation in data:
7474
inputs = create_input(citation, model.parameters, False)

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@
206206
# Train network
207207
#
208208
singletons = set([word_to_id[k] for k, v in dico_words_train.items() if v == 1])
209-
n_epochs = 10 # number of epochs over the training set
209+
n_epochs = 1 # number of epochs over the training set
210210
freq_eval = 1000 # evaluate on dev every freq_eval steps
211211
best_dev = -np.inf
212212
best_test = -np.inf

utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import codecs
55
import numpy as np
66
import theano
7-
from sklearn import metrics
87

98
models_path = "./models"
109
eval_path = "./evaluation"
@@ -223,6 +222,8 @@ def evaluate(parameters, f_eval, raw_sentences, parsed_sentences,
223222
"""
224223
Evaluate current model using CoNLL script.
225224
"""
225+
# Make sklearn import at runtime only
226+
from sklearn import metrics
226227
results = {'real': [], 'predicted': []}
227228

228229
for _, data in zip(raw_sentences, parsed_sentences):

0 commit comments

Comments (0)