|
5 | 5 | import json |
6 | 6 | import numpy as np |
7 | 7 | import theano |
8 | | -from gensim.models import KeyedVectors |
| 8 | +from contextlib import closing |
9 | 9 | from utils import evaluate, create_input |
10 | 10 | from model import Model |
11 | 11 | from loader import augment_with_pretrained, load_sentences, prepare_dataset |
|
# Parse the command-line options configured earlier in this script
# (optparser is built above; only the resulting options object is used here).
opts = optparser.parse_args()[0]

# Load the trained tagger and point its configuration at the pretrained
# embedding file named on the command line. The path is resolved relative
# to the current working directory so model.build()/reload() can locate it
# regardless of where the script is launched from.
model = Model(model_path=opts.model_path)
model.parameters['pre_emb'] = os.path.join(os.getcwd(), opts.pre_emb)

# Build the computation graph in inference mode (training=False), then
# restore the saved weights into it. f holds the compiled theano
# functions returned by build(); f[1] is used below for tagging.
f = model.build(training=False, **model.parameters)

model.reload()
|
# Text-normalization flags the model was trained with; reused below so the
# test input is preprocessed the same way as the training data
# (lowercasing and digit-zeroing).
lower = model.parameters['lower']
zeros = model.parameters['zeros']
|
if opts.run == 'file':
    # Batch mode: read the whole input file; one citation string per line.
    assert opts.input_file
    assert opts.output_file

    output_file = opts.output_file

    # Plain `with open(...)` replaces closing(open(...)) — file objects
    # are already context managers.
    with open(opts.input_file, 'r') as in_fh:
        data = in_fh.read()
    strings = data.split('\n')
else:
    # Interactive mode: tag a single citation typed on stdin.
    # NOTE: raw_input is the Python 2 builtin; this script targets Python 2.
    string = raw_input("Enter the citation string: ")
    strings = [string]

# Re-write the citations one token per line into a scratch file in the
# format expected by load_sentences(). Renamed from `file`, which
# shadowed the builtin; `with` guarantees the handle is closed.
test_file = "test_file"
if os.path.exists(test_file):
    os.remove(test_file)
with open(test_file, 'a') as out_fh:  # 'a' after remove == fresh file
    for string in strings:
        out_fh.write('\n'.join(string.split()) + '\n')
# Load and preprocess the scratch file exactly like training data
# (word_to_id / char_to_id are built earlier in this script).
test_sentences = load_sentences(test_file, lower, zeros)
data = prepare_dataset(test_sentences, word_to_id, char_to_id, lower, True)

for citation in data:
    # f[1] is the compiled tagging function from model.build(); the
    # [1:-1] slice drops the sentence-boundary positions.
    inputs = create_input(citation, model.parameters, False)
    y_pred = np.array(f[1](*inputs))[1:-1]

    # Map predicted tag ids back to tag names, then pair each token with
    # its tag as one "word<TAB>tag" entry. Raises IndexError if the
    # prediction is shorter than the token list (as the original did).
    tags = [model.id_to_tag[tag_id] for tag_id in y_pred]
    output = [word + '\t' + tags[idx]
              for idx, word in enumerate(citation['str_words'])]

    if opts.run == 'file':
        # File mode: write the first citation's tags and stop. The write
        # and the break were guarded by two identical checks before;
        # merged into one branch with unchanged behavior.
        with open(output_file, 'w') as out_fh:
            out_fh.write('\n'.join(output))
        break
    else:
        print('\n'.join(output))
0 commit comments