Commit 5c62399

Merge pull request #1348 from Zrealshadow/dev-postgresql-patch-9
2 parents 4592c75 + 86ae4c1

1 file changed: +51 -1 lines

examples/singa_peft/examples/model/char_rnn.py

Lines changed: 51 additions & 1 deletion
@@ -88,6 +88,7 @@ def set_states(self, states):
         self.hx.copy_from(states[self.hx.name])
         super().set_states(states)
 
+
 class Data(object):
 
     def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
@@ -205,4 +206,53 @@ def evaluate(model, data, batch_size, seq_length, dev, inputs, labels):
         loss = autograd.softmax_cross_entropy(y, labels)[0]
         val_loss += tensor.to_numpy(loss)[0]
     print('    validation loss is %f' %
-          (val_loss / data.num_test_batch / seq_length))
+          (val_loss / data.num_test_batch / seq_length))
+
+
+def train(data,
+          max_epoch,
+          hidden_size=100,
+          seq_length=100,
+          batch_size=16,
+          model_path='model'):
+    # SGD with L2 gradient normalization
+    cuda = device.create_cuda_gpu()
+    model = CharRNN(data.vocab_size, hidden_size)
+    model.graph(True, False)
+
+    inputs, labels = None, None
+
+    for epoch in range(max_epoch):
+        model.train()
+        train_loss = 0
+        for b in tqdm(range(data.num_train_batch)):
+            batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
+            inputs, labels = convert(batch, batch_size, seq_length,
+                                     data.vocab_size, cuda, inputs, labels)
+            out, loss = model(inputs, labels)
+            model.reset_states(cuda)
+            train_loss += tensor.to_numpy(loss)[0]
+
+        print('\nEpoch %d, train loss is %f' %
+              (epoch, train_loss / data.num_train_batch / seq_length))
+
+        evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
+        sample(model, data, cuda)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Train multi-stack LSTM for '
+        'modeling character sequence from plain text files')
+    parser.add_argument('data', type=str, help='training file')
+    parser.add_argument('-b', type=int, default=32, help='batch_size')
+    parser.add_argument('-l', type=int, default=64, help='sequence length')
+    parser.add_argument('-d', type=int, default=128, help='hidden size')
+    parser.add_argument('-m', type=int, default=50, help='max num of epoch')
+    args = parser.parse_args()
+    data = Data(args.data, batch_size=args.b, seq_length=args.l)
+    train(data,
+          args.m,
+          hidden_size=args.d,
+          seq_length=args.l,
+          batch_size=args.b)
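
For reference, a minimal usage sketch of the train() function this commit adds, called programmatically rather than through the argparse entry point. The import path char_rnn and the corpus path corpus.txt are hypothetical stand-ins, not part of the commit; the parameter values mirror the CLI defaults above.

# Usage sketch (illustration only, not part of the commit).
# 'char_rnn' as the module name and 'corpus.txt' as the training
# file are hypothetical stand-ins.
from char_rnn import Data, train

# Build character batches from a plain-text corpus
# (mirrors the CLI defaults: -b 32, -l 64).
data = Data('corpus.txt', batch_size=32, seq_length=64)

# train() creates its own CUDA device and runs evaluate() and
# sample() at the end of each epoch.
train(data, max_epoch=5, hidden_size=128, seq_length=64, batch_size=32)

Note that batch_size and seq_length passed to train() should match those used to construct Data, since train() slices data.train_dat by batch_size; the __main__ block in the diff keeps them in sync via args.b and args.l.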
