1+ #
2+ # Licensed to the Apache Software Foundation (ASF) under one
3+ # or more contributor license agreements. See the NOTICE file
4+ # distributed with this work for additional information
5+ # regarding copyright ownership. The ASF licenses this file
6+ # to you under the Apache License, Version 2.0 (the
7+ # "License"); you may not use this file except in compliance
8+ # with the License. You may obtain a copy of the License at
9+ #
10+ # http://www.apache.org/licenses/LICENSE-2.0
11+ #
12+ # Unless required by applicable law or agreed to in writing,
13+ # software distributed under the License is distributed on an
14+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+ # KIND, either express or implied. See the License for the
16+ # specific language governing permissions and limitations
17+ # under the License.
18+ #
19+
20+ from singa import layer
21+ from singa import model
22+
23+
24+ class CNN (model .Model ):
25+
26+ def __init__ (self , num_classes = 10 , num_channels = 1 ):
27+ super (CNN , self ).__init__ ()
28+ self .num_classes = num_classes
29+ self .input_size = 28
30+ self .dimension = 4
31+ self .conv1 = layer .Conv2d (num_channels , 20 , 5 , padding = 0 , activation = "RELU" )
32+ self .conv2 = layer .Conv2d (20 , 50 , 5 , padding = 0 , activation = "RELU" )
33+ self .linear1 = layer .Linear (500 )
34+ self .linear2 = layer .Linear (num_classes )
35+ self .pooling1 = layer .MaxPool2d (2 , 2 , padding = 0 )
36+ self .pooling2 = layer .MaxPool2d (2 , 2 , padding = 0 )
37+ self .relu = layer .ReLU ()
38+ self .flatten = layer .Flatten ()
39+ self .softmax_cross_entropy = layer .SoftMaxCrossEntropy ()
40+
41+ def forward (self , x ):
42+ y = self .conv1 (x )
43+ y = self .pooling1 (y )
44+ y = self .conv2 (y )
45+ y = self .pooling2 (y )
46+ y = self .flatten (y )
47+ y = self .linear1 (y )
48+ y = self .relu (y )
49+ y = self .linear2 (y )
50+ return y
51+
52+ def train_one_batch (self , x , y , dist_option , spars ):
53+ out = self .forward (x )
54+ loss = self .softmax_cross_entropy (out , y )
55+
56+ if dist_option == 'plain' :
57+ self .optimizer (loss )
58+ elif dist_option == 'half' :
59+ self .optimizer .backward_and_update_half (loss )
60+ elif dist_option == 'partialUpdate' :
61+ self .optimizer .backward_and_partial_update (loss )
62+ elif dist_option == 'sparseTopK' :
63+ self .optimizer .backward_and_sparse_update (loss ,
64+ topK = True ,
65+ spars = spars )
66+ elif dist_option == 'sparseThreshold' :
67+ self .optimizer .backward_and_sparse_update (loss ,
68+ topK = False ,
69+ spars = spars )
70+ return out , loss
71+
72+ def set_optimizer (self , optimizer ):
73+ self .optimizer = optimizer
74+
75+
76+ def create_model (pretrained = False , ** kwargs ):
77+ """Constructs a CNN model.
78+
79+ Args:
80+ pretrained (bool): If True, returns a pre-trained model.
81+
82+ Returns:
83+ The created CNN model.
84+ """
85+ model = CNN (** kwargs )
86+
87+ return model
88+
89+
90+ __all__ = ['CNN' , 'create_model' ]
0 commit comments