Skip to content

Commit 7122002

Browse files
Pandorolucasb-eyer
authored andcommitted
Removed CUDnn requirements from the example model.
The old model is now called lenet_cudnn, the new one is similar, but without the CUDnn convolutions.
1 parent 7ec6e58 commit 7122002

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

examples/MNIST/model.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def net():
1717
return model
1818

1919

20-
def lenet():
20+
def lenet_cudnn():
2121
model = df.Sequential()
2222
model.add(df.Reshape(-1, 1, 28, 28))
2323
model.add(df.SpatialConvolutionCUDNN(1, 32, 5, 5, 1, 1, 2, 2, with_bias=False))
@@ -40,3 +40,26 @@ def lenet():
4040
model.add(df.SoftMax())
4141
return model
4242

43+
44+
def lenet():
45+
model = df.Sequential()
46+
model.add(df.Reshape(-1, 1, 28, 28))
47+
model.add(df.SpatialConvolution(1, 32, 5, 5, 1, 1, with_bias=False))
48+
model.add(df.BatchNormalization(32))
49+
model.add(df.ReLU())
50+
model.add(df.SpatialMaxPooling(2, 2))
51+
52+
model.add(df.SpatialConvolution(32, 64, 5, 5, 1, 1, with_bias=False))
53+
model.add(df.BatchNormalization(64))
54+
model.add(df.ReLU())
55+
model.add(df.SpatialMaxPooling(2, 2))
56+
model.add(df.Reshape(-1, 4*4*64))
57+
58+
model.add(df.Linear(4*4*64, 100, with_bias=False))
59+
model.add(df.BatchNormalization(100))
60+
model.add(df.ReLU())
61+
model.add(df.Dropout(0.5))
62+
63+
model.add(df.Linear(100, 10))
64+
model.add(df.SoftMax())
65+
return model

0 commit comments

Comments
 (0)