1+ import torch .nn as nn
2+ import torch
3+
4+ NumF = 12
5+ numHops = 98
6+ timePoolSize = 13
7+ dropoutProb = 0.2
8+ numClasses = 11
9+
10+ class CNN (nn .Module ):
11+
12+ # Contructor
13+ def __init__ (self , out_1 = NumF ):
14+ super (CNN , self ).__init__ ()
15+ self .cnn1 = nn .Conv2d (in_channels = 1 , out_channels = out_1 , kernel_size = 3 , padding = 1 )
16+ self .batch1 = nn .BatchNorm2d (out_1 )
17+ self .relu1 = nn .ReLU ()
18+
19+ self .maxpool1 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
20+
21+ self .cnn2 = nn .Conv2d (in_channels = out_1 , out_channels = 2 * out_1 , kernel_size = 3 , padding = 1 )
22+ self .batch2 = nn .BatchNorm2d (2 * out_1 )
23+ self .relu2 = nn .ReLU ()
24+
25+ self .maxpool2 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
26+
27+ self .cnn3 = nn .Conv2d (in_channels = 2 * out_1 , out_channels = 4 * out_1 , kernel_size = 3 , padding = 1 )
28+ self .batch3 = nn .BatchNorm2d (4 * out_1 )
29+ self .relu3 = nn .ReLU ()
30+
31+ self .maxpool3 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
32+
33+ self .cnn4 = nn .Conv2d (in_channels = 4 * out_1 , out_channels = 4 * out_1 , kernel_size = 3 , padding = 1 )
34+ self .batch4 = nn .BatchNorm2d (4 * out_1 )
35+ self .relu4 = nn .ReLU ()
36+ self .cnn5 = nn .Conv2d (in_channels = 4 * out_1 , out_channels = 4 * out_1 , kernel_size = 3 , padding = 1 )
37+ self .batch5 = nn .BatchNorm2d (4 * out_1 )
38+ self .relu5 = nn .ReLU ()
39+
40+ self .maxpool4 = nn .MaxPool2d (kernel_size = (timePoolSize , 1 ))
41+
42+ self .dropout = nn .Dropout2d (dropoutProb )
43+
44+ self .fc = nn .Linear (336 , numClasses )
45+
46+ # Prediction
47+ def forward (self , x ):
48+
49+ out = self .cnn1 (x )
50+ out = self .batch1 (out )
51+ out = self .relu1 (out )
52+
53+ out = self .maxpool1 (out )
54+
55+ out = self .cnn2 (out )
56+ out = self .batch2 (out )
57+ out = self .relu2 (out )
58+
59+ out = self .maxpool2 (out )
60+
61+ out = self .cnn3 (out )
62+ out = self .batch3 (out )
63+ out = self .relu3 (out )
64+
65+ out = self .maxpool3 (out )
66+
67+ out = self .cnn4 (out )
68+ out = self .batch4 (out )
69+ out = self .relu4 (out )
70+ out = self .cnn5 (out )
71+ out = self .batch5 (out )
72+ out = self .relu5 (out )
73+
74+ out = self .maxpool4 (out )
75+
76+ out = self .dropout (out )
77+
78+ out = out .view (out .size (0 ), - 1 )
79+ out = self .fc (out )
80+
81+ return out
0 commit comments