@@ -88,6 +88,7 @@ def set_states(self, states):
         self.hx.copy_from(states[self.hx.name])
         super().set_states(states)
 
+
 class Data(object):
 
     def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
@@ -205,4 +206,53 @@ def evaluate(model, data, batch_size, seq_length, dev, inputs, labels):
         loss = autograd.softmax_cross_entropy(y, labels)[0]
         val_loss += tensor.to_numpy(loss)[0]
     print('    validation loss is %f' %
-        (val_loss / data.num_test_batch / seq_length))
+          (val_loss / data.num_test_batch / seq_length))
+
+
+def train(data,
+          max_epoch,
+          hidden_size=100,
+          seq_length=100,
+          batch_size=16,
+          model_path='model'):
+    # SGD with L2 gradient normalization
+    cuda = device.create_cuda_gpu()
+    model = CharRNN(data.vocab_size, hidden_size)
+    model.graph(True, False)
+
+    inputs, labels = None, None
+
+    for epoch in range(max_epoch):
+        model.train()
+        train_loss = 0
+        for b in tqdm(range(data.num_train_batch)):
+            batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
+            inputs, labels = convert(batch, batch_size, seq_length,
+                                     data.vocab_size, cuda, inputs, labels)
+            out, loss = model(inputs, labels)
+            model.reset_states(cuda)
+            train_loss += tensor.to_numpy(loss)[0]
+
+        print('\nEpoch %d, train loss is %f' %
+              (epoch, train_loss / data.num_train_batch / seq_length))
+
+        evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
+        sample(model, data, cuda)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Train multi-stack LSTM for '
+        'modeling character sequences from plain text files')
+    parser.add_argument('data', type=str, help='training file')
+    parser.add_argument('-b', type=int, default=32, help='batch size')
+    parser.add_argument('-l', type=int, default=64, help='sequence length')
+    parser.add_argument('-d', type=int, default=128, help='hidden size')
+    parser.add_argument('-m', type=int, default=50, help='max number of epochs')
+    args = parser.parse_args()
+    data = Data(args.data, batch_size=args.b, seq_length=args.l)
+    train(data,
+          args.m,
+          hidden_size=args.d,
+          seq_length=args.l,
+          batch_size=args.b)
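For context, here is a minimal sketch of how the new train() entry point can be driven programmatically rather than through argparse. The module name train and the corpus file corpus.txt are placeholders, not part of this commit; the values mirror the defaults declared in the __main__ block above:

    # hypothetical driver; 'train' as module name and 'corpus.txt' are
    # assumptions for illustration only
    from train import Data, train

    # same defaults as the argparse flags: -b 32, -l 64, -d 128, -m 50
    data = Data('corpus.txt', batch_size=32, seq_length=64)
    train(data, 50, hidden_size=128, seq_length=64, batch_size=32)

This is equivalent to the command-line invocation python train.py corpus.txt -b 32 -l 64 -d 128 -m 50, assuming the script is saved as train.py.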