# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
'''Train a Char-RNN model on plain text files.
The model follows https://github.com/karpathy/char-rnn
The training file can be any plain text file,
e.g., http://cs.stanford.edu/people/karpathy/char-rnn/
'''

from __future__ import division
from __future__ import print_function
from builtins import range
import sys
import argparse

import numpy as np
from tqdm import tqdm

from singa import device
from singa import tensor
from singa import autograd
from singa import layer
from singa import model
from singa import opt


class CharRNN(model.Model):
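    '''Character-level language model: an LSTM over one-hot character
    vectors; the per-step outputs are concatenated, flattened and passed
    through a dense layer to produce a next-character logit for every
    position in the batch of sequences.
    '''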

    def __init__(self, vocab_size, hidden_size=32):
        super(CharRNN, self).__init__()
        self.rnn = layer.LSTM(vocab_size, hidden_size)
        self.cat = layer.Cat()
        self.reshape1 = layer.Reshape()
        self.dense = layer.Linear(hidden_size, vocab_size)
        self.reshape2 = layer.Reshape()
        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
        self.optimizer = opt.SGD(0.01)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

    def reset_states(self, dev):
        # move the persistent LSTM states onto `dev` and zero them
        self.hx.to_device(dev)
        self.cx.to_device(dev)
        self.hx.set_value(0.0)
        self.cx.set_value(0.0)

    def initialize(self, inputs):
        # lazily create the hidden and cell state tensors once the batch
        # size is known from the first batch of inputs
        batchsize = inputs[0].shape[0]
        self.hx = tensor.Tensor((batchsize, self.hidden_size))
        self.cx = tensor.Tensor((batchsize, self.hidden_size))
        self.reset_states(inputs[0].device)

    def forward(self, inputs):
        # run the LSTM over the sequence, then carry the final states over
        # to the next batch; copy_data copies values only, so gradients do
        # not flow across batch boundaries
        x, hx, cx = self.rnn(inputs, (self.hx, self.cx))
        self.hx.copy_data(hx)
        self.cx.copy_data(cx)
        # concatenate the per-step outputs and project them to vocab logits
        x = self.cat(x)
        x = self.reshape1(x, (-1, self.hidden_size))
        return self.dense(x)

    def train_one_batch(self, x, y):
        out = self.forward(x)
        # flatten the labels to one column so each row matches one logit row
        y = self.reshape2(y, (-1, 1))
        loss = self.softmax_cross_entropy(out, y)
        self.optimizer(loss)  # backward pass and parameter update
        return out, loss

    def get_states(self):
        # include the LSTM states so they are saved along with checkpoints
        ret = super().get_states()
        ret[self.hx.name] = self.hx
        ret[self.cx.name] = self.cx
        return ret

    def set_states(self, states):
        self.hx.copy_from(states[self.hx.name])
        self.cx.copy_from(states[self.cx.name])
        super().set_states(states)
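
# A minimal sketch of how this class would be driven, assuming SINGA's
# Model dispatches __call__ to train_one_batch in training mode, and
# assuming a hypothetical helper `one_hot_batches` (not part of this file)
# that yields x as a list of (batch, vocab_size) one-hot tensors, one per
# time step, and y as the integer labels of the next characters:
#
#     dev = device.get_default_device()
#     m = CharRNN(vocab_size=101, hidden_size=32)
#     m.train()
#     for x, y in one_hot_batches(dev):
#         out, loss = m(x, y)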