@@ -10,7 +10,7 @@
 
 from model.model import CharRNN
 from model.vocab import START_CHAR, END_CHAR
-from train import get_vocab_from_file
+from model.vocab import get_vocab_from_file
 
 file_path = os.path.dirname(os.path.realpath(__file__))
 lib_path = os.path.abspath(os.path.join(file_path, '..'))
@@ -70,6 +70,7 @@
     'output',
     'input',
     'nsamples',
+    'model',
 ]
 
 
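Note on the hunk above: adding 'model' to the required-parameter list means the merged parameter dictionary must carry a model name before run() is called; run() then wraps that dictionary with candle.ArgumentStruct(**params) so entries become attributes (args.model, args.data_url). A minimal sketch of that wrapping pattern, assuming ArgumentStruct does little more than promote dictionary keys to attributes (the real CANDLE class may do more):

# Illustrative stand-in, not the actual CANDLE implementation.
class ArgumentStruct:
    """Expose a parameter dictionary through attribute access."""
    def __init__(self, **params):
        self.__dict__.update(params)

params = {'model': 'autosave.model.pt', 'nsamples': 100, 'use_gpus': False}
args = ArgumentStruct(**params)
assert args.model == 'autosave.model.pt'
assert args.nsamples == 100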
@@ -175,6 +176,18 @@ def run(params):
     print("Note: This script is very picky. Please check device output to see where this is running.")
     args = candle.ArgumentStruct(**params)
 
+    data_url = args.data_url
+
+    if args.model == 'ft_goodperforming_model.pt':
+        file = 'pilot1/ft_goodperforming_model.pt'
+    elif args.model == 'ft_poorperforming_model.pt':
+        file = 'pilot1/ft_poorperforming_model.pt'
+    else:  # corresponding to args.model == 'autosave.model.pt'
+        file = 'mosesrun/autosave.model.pt'
+
+    print('Recovering trained model')
+    trained = candle.fetch_file(data_url + file, subdir='examples/rnngen')
+
     # Configure GPU
     if args.use_gpus and torch.cuda.is_available():
         device = 'cuda'
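Note on the hunk above: the new block maps the user-facing model name to a relative path and hands data_url + file to candle.fetch_file, which downloads the file into a local cache under the given subdir and returns the cached path (assigned to trained, consumed by torch.load in the next hunk). A hedged sketch of that fetch-and-cache behavior, with urllib standing in for the CANDLE helper; the cache location and exact semantics here are assumptions for illustration:

# Simplified stand-in for candle.fetch_file -- illustration only.
import os
import urllib.request

def fetch_file(url, subdir='examples/rnngen', cache_root='~/.candle'):
    """Download url into a local cache if absent; return the local path."""
    cache_dir = os.path.join(os.path.expanduser(cache_root), subdir)
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, os.path.basename(url))
    if not os.path.exists(local_path):
        urllib.request.urlretrieve(url, local_path)
    return local_path

# Mirroring the diff: trained = fetch_file(data_url + 'mosesrun/autosave.model.pt')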
@@ -188,10 +201,12 @@ def run(params):
     model = CharRNN(len(vocab), len(vocab), max_len=args.maxlen).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    pt = torch.load(args.logdir + "/" + args.model, map_location=device)
+    print("Loading trained model.")
+    pt = torch.load(trained, map_location=device)
     model.load_state_dict(pt['state_dict'])
     optimizer.load_state_dict(pt['optim_state_dict'])
 
+    print("Applying to loaded data")
     total_sampled = 0
     total_valid = 0
     total_unqiue = 0
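Note on the hunk above: the load path expects the checkpoint to be a dictionary holding model and optimizer state under the keys 'state_dict' and 'optim_state_dict'; the corresponding save side is not part of this change, so it is inferred here. A minimal round trip in plain PyTorch, with a toy module standing in for CharRNN:

import torch
import torch.nn as nn

# Toy stand-in; the script builds CharRNN(len(vocab), len(vocab), max_len=args.maxlen).
model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save side (inferred from the keys the diff reads):
torch.save({'state_dict': model.state_dict(),
            'optim_state_dict': optimizer.state_dict()}, 'checkpoint.pt')

# Load side, as in the diff; map_location keeps loading device-agnostic.
pt = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(pt['state_dict'])
optimizer.load_state_dict(pt['optim_state_dict'])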