Skip to content

Commit ef5715e

Browse files
authored
Merge pull request #1317 from zlheui/add_cnn_singa_peft_example
Add the cnn model for the singa peft example
2 parents b6720a9 + e1805fa commit ef5715e

File tree

1 file changed

+90
-0
lines changed
  • examples/singa_peft/examples/model

1 file changed

+90
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import layer
21+
from singa import model
22+
23+
24+
class CNN(model.Model):
25+
26+
def __init__(self, num_classes=10, num_channels=1):
27+
super(CNN, self).__init__()
28+
self.num_classes = num_classes
29+
self.input_size = 28
30+
self.dimension = 4
31+
self.conv1 = layer.Conv2d(num_channels, 20, 5, padding=0, activation="RELU")
32+
self.conv2 = layer.Conv2d(20, 50, 5, padding=0, activation="RELU")
33+
self.linear1 = layer.Linear(500)
34+
self.linear2 = layer.Linear(num_classes)
35+
self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
36+
self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
37+
self.relu = layer.ReLU()
38+
self.flatten = layer.Flatten()
39+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
40+
41+
def forward(self, x):
42+
y = self.conv1(x)
43+
y = self.pooling1(y)
44+
y = self.conv2(y)
45+
y = self.pooling2(y)
46+
y = self.flatten(y)
47+
y = self.linear1(y)
48+
y = self.relu(y)
49+
y = self.linear2(y)
50+
return y
51+
52+
def train_one_batch(self, x, y, dist_option, spars):
53+
out = self.forward(x)
54+
loss = self.softmax_cross_entropy(out, y)
55+
56+
if dist_option == 'plain':
57+
self.optimizer(loss)
58+
elif dist_option == 'half':
59+
self.optimizer.backward_and_update_half(loss)
60+
elif dist_option == 'partialUpdate':
61+
self.optimizer.backward_and_partial_update(loss)
62+
elif dist_option == 'sparseTopK':
63+
self.optimizer.backward_and_sparse_update(loss,
64+
topK=True,
65+
spars=spars)
66+
elif dist_option == 'sparseThreshold':
67+
self.optimizer.backward_and_sparse_update(loss,
68+
topK=False,
69+
spars=spars)
70+
return out, loss
71+
72+
def set_optimizer(self, optimizer):
73+
self.optimizer = optimizer
74+
75+
76+
def create_model(pretrained=False, **kwargs):
77+
"""Constructs a CNN model.
78+
79+
Args:
80+
pretrained (bool): If True, returns a pre-trained model.
81+
82+
Returns:
83+
The created CNN model.
84+
"""
85+
model = CNN(**kwargs)
86+
87+
return model
88+
89+
90+
__all__ = ['CNN', 'create_model']

0 commit comments

Comments
 (0)