Commit b8e867b

Merge pull request #1345 from gzrp/dev-postgresql
Commit message: Add the sequence model for the peft example
2 parents 3263218 + f3337a7

File tree: 1 file changed (+89 −0 lines)

@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
'''Train a Char-RNN model using plain text files.
The model is created following https://github.com/karpathy/char-rnn
The training file can be any plain-text file,
e.g., http://cs.stanford.edu/people/karpathy/char-rnn/
'''

from __future__ import division
from __future__ import print_function
from builtins import range
import numpy as np
import sys
import argparse
from tqdm import tqdm

from singa import device
from singa import tensor
from singa import autograd
from singa import layer
from singa import model
from singa import opt

class CharRNN(model.Model):

    def __init__(self, vocab_size, hidden_size=32):
        super(CharRNN, self).__init__()
        self.rnn = layer.LSTM(vocab_size, hidden_size)
        self.cat = layer.Cat()
        self.reshape1 = layer.Reshape()
        self.dense = layer.Linear(hidden_size, vocab_size)
        self.reshape2 = layer.Reshape()
        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
        self.optimizer = opt.SGD(0.01)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

    def reset_states(self, dev):
        self.hx.to_device(dev)
        self.cx.to_device(dev)
        self.hx.set_value(0.0)
        self.cx.set_value(0.0)

    def initialize(self, inputs):
        batchsize = inputs[0].shape[0]
        self.hx = tensor.Tensor((batchsize, self.hidden_size))
        self.cx = tensor.Tensor((batchsize, self.hidden_size))
        self.reset_states(inputs[0].device)

    def forward(self, inputs):
        x, hx, cx = self.rnn(inputs, (self.hx, self.cx))
        # persist the final hidden and cell states so they carry over
        # to the next batch
        self.hx.copy_data(hx)
        self.cx.copy_data(cx)
        # concatenate the per-timestep outputs and project each step
        # back to vocabulary logits
        x = self.cat(x)
        x = self.reshape1(x, (-1, self.hidden_size))
        return self.dense(x)

    def train_one_batch(self, x, y):
        out = self.forward(x)
        # flatten the labels to line up with the (batch * steps) outputs
        y = self.reshape2(y, (-1, 1))
        loss = self.softmax_cross_entropy(out, y)
        self.optimizer(loss)
        return out, loss

    def get_states(self):
        ret = super().get_states()
        ret[self.hx.name] = self.hx
        ret[self.cx.name] = self.cx
        return ret

    def set_states(self, states):
        self.hx.copy_from(states[self.hx.name])
        self.cx.copy_from(states[self.cx.name])
        super().set_states(states)
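
For readers who want to try the class out, here is a minimal driver sketch. It is not part of this commit: the encode_batch helper, the batch layout, and the direct call to initialize() are assumptions inferred from forward() above (which takes a list of per-timestep one-hot tensors) and from the SINGA tensor API; real SINGA training scripts normally reach initialize() through Model.compile() instead.

# Hypothetical driver for CharRNN -- a sketch, not part of this commit.
# Assumes inputs are a list of per-timestep one-hot tensors (as
# forward() above suggests) and labels give the next character at each
# step; encode_batch and all sizes below are illustrative only.
import numpy as np
from singa import device, tensor

seq_len, batch_size = 32, 16
text = open('input.txt').read()         # any plain-text corpus
char_to_idx = {c: i for i, c in enumerate(sorted(set(text)))}
vocab_size = len(char_to_idx)
stride = (len(text) - 1) // batch_size  # one text stream per batch row

dev = device.get_default_device()       # or device.create_cuda_gpu()
net = CharRNN(vocab_size)

def encode_batch(offset):
    '''One-hot encode seq_len steps for batch_size parallel streams.'''
    xs = []
    ys = np.zeros((batch_size, seq_len), dtype=np.int32)
    for t in range(seq_len):
        x = np.zeros((batch_size, vocab_size), dtype=np.float32)
        for b in range(batch_size):
            pos = b * stride + offset + t
            x[b, char_to_idx[text[pos]]] = 1.0
            ys[b, t] = char_to_idx[text[pos + 1]]  # next-char label
        tx = tensor.from_numpy(x)
        tx.to_device(dev)
        xs.append(tx)
    ty = tensor.from_numpy(ys)
    ty.to_device(dev)
    return xs, ty

xs, ty = encode_batch(0)
net.initialize(xs)  # simplification: SINGA scripts usually get here
                    # via Model.compile()
for step in range(100):
    offset = (step * seq_len) % (stride - seq_len)
    xs, ty = encode_batch(offset)
    out, loss = net.train_one_batch(xs, ty)
    print('step', step, 'loss', tensor.to_numpy(loss))

Because forward() copies the final LSTM states back into self.hx and self.cx, batches taken at consecutive offsets let the network carry context across a long text, and reset_states() zeroes that context between epochs.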
