1
+ import argparse
2
+ import torch
3
+ import torch .nn as nn
4
+ import torch .nn .functional as F
5
+ import torch .optim as optim
6
+ from torch .optim .lr_scheduler import StepLR
7
+ from torchvision import datasets , transforms
8
+
9
+ # ---------- Core Swin Components ----------
10
+
11
+ class PatchEmbed (nn .Module ):
12
+ def __init__ (self , img_size = 32 , patch_size = 4 , in_chans = 3 , embed_dim = 48 ):
13
+ super ().__init__ ()
14
+ self .proj = nn .Conv2d (in_chans , embed_dim , kernel_size = patch_size , stride = patch_size )
15
+ self .norm = nn .LayerNorm (embed_dim )
16
+
17
+ def forward (self , x ):
18
+ x = self .proj (x )
19
+ x = x .flatten (2 ).transpose (1 , 2 )
20
+ x = self .norm (x )
21
+ return x
22
+
23
+ def window_partition (x , window_size ):
24
+ B , H , W , C = x .shape
25
+ x = x .view (B , H // window_size , window_size , W // window_size , window_size , C )
26
+ windows = x .permute (0 , 1 , 3 , 2 , 4 , 5 ).contiguous ().view (- 1 , window_size , window_size , C )
27
+ return windows
28
+
29
+ def window_reverse (windows , window_size , H , W ):
30
+ B = int (windows .shape [0 ] / (H * W / window_size / window_size ))
31
+ x = windows .view (B , H // window_size , W // window_size , window_size , window_size , - 1 )
32
+ x = x .permute (0 , 1 , 3 , 2 , 4 , 5 ).contiguous ().view (B , H , W , - 1 )
33
+ return x
34
+
35
+ class WindowAttention (nn .Module ):
36
+ def __init__ (self , dim , window_size , num_heads ):
37
+ super ().__init__ ()
38
+ self .num_heads = num_heads
39
+ head_dim = dim // num_heads
40
+ self .scale = head_dim ** - 0.5
41
+
42
+ self .qkv = nn .Linear (dim , dim * 3 )
43
+ self .proj = nn .Linear (dim , dim )
44
+
45
+ def forward (self , x ):
46
+ B_ , N , C = x .shape
47
+ qkv = self .qkv (x ).reshape (B_ , N , 3 , self .num_heads , C // self .num_heads )
48
+ q , k , v = qkv .permute (2 , 0 , 3 , 1 , 4 )
49
+
50
+ attn = (q @ k .transpose (- 2 , - 1 )) * self .scale
51
+ attn = attn .softmax (dim = - 1 )
52
+
53
+ out = (attn @ v ).transpose (1 , 2 ).reshape (B_ , N , C )
54
+ return self .proj (out )
55
+
56
+ class SwinTransformerBlock (nn .Module ):
57
+ def __init__ (self , dim , input_resolution , num_heads , window_size = 4 , shift_size = 0 ):
58
+ super ().__init__ ()
59
+ self .dim = dim
60
+ self .input_resolution = input_resolution
61
+ self .window_size = window_size
62
+ self .shift_size = shift_size
63
+
64
+ self .norm1 = nn .LayerNorm (dim )
65
+ self .attn = WindowAttention (dim , window_size , num_heads )
66
+ self .norm2 = nn .LayerNorm (dim )
67
+
68
+ self .mlp = nn .Sequential (
69
+ nn .Linear (dim , dim * 4 ),
70
+ nn .GELU (),
71
+ nn .Linear (dim * 4 , dim )
72
+ )
73
+
74
+ def forward (self , x ):
75
+ H , W = self .input_resolution
76
+ B , L , C = x .shape
77
+ x = x .view (B , H , W , C )
78
+
79
+ if self .shift_size > 0 :
80
+ shifted_x = torch .roll (x , (- self .shift_size , - self .shift_size ), (1 , 2 ))
81
+ else :
82
+ shifted_x = x
83
+
84
+ windows = window_partition (shifted_x , self .window_size )
85
+ windows = windows .view (- 1 , self .window_size * self .window_size , C )
86
+
87
+ attn_windows = self .attn (self .norm1 (windows ))
88
+ attn_windows = attn_windows .view (- 1 , self .window_size , self .window_size , C )
89
+
90
+ shifted_x = window_reverse (attn_windows , self .window_size , H , W )
91
+
92
+ if self .shift_size > 0 :
93
+ x = torch .roll (shifted_x , (self .shift_size , self .shift_size ), (1 , 2 ))
94
+ else :
95
+ x = shifted_x
96
+
97
+ x = x .view (B , H * W , C )
98
+ x = x + self .mlp (self .norm2 (x ))
99
+ return x
100
+
101
+ # ---------- Final Network ----------
102
+
103
+ class SwinTinyNet (nn .Module ):
104
+ def __init__ (self , num_classes = 10 ):
105
+ super (SwinTinyNet , self ).__init__ ()
106
+ self .patch_embed = PatchEmbed (img_size = 32 , patch_size = 4 , in_chans = 3 , embed_dim = 48 )
107
+ self .block1 = SwinTransformerBlock (dim = 48 , input_resolution = (8 , 8 ), num_heads = 3 , window_size = 4 , shift_size = 0 )
108
+ self .block2 = SwinTransformerBlock (dim = 48 , input_resolution = (8 , 8 ), num_heads = 3 , window_size = 4 , shift_size = 2 )
109
+ self .norm = nn .LayerNorm (48 )
110
+ self .fc = nn .Linear (48 , num_classes )
111
+
112
+ def forward (self , x ):
113
+ x = self .patch_embed (x )
114
+ x = self .block1 (x )
115
+ x = self .block2 (x )
116
+ x = self .norm (x )
117
+ x = x .mean (dim = 1 )
118
+ x = self .fc (x )
119
+ return F .log_softmax (x , dim = 1 )
120
+
121
+ # ---------- Training and Testing ----------
122
+
123
+ def train (args , model , device , train_loader , optimizer , epoch ):
124
+ model .train ()
125
+ for batch_idx , (data , target ) in enumerate (train_loader ):
126
+ data , target = data .to (device ), target .to (device )
127
+ optimizer .zero_grad ()
128
+ output = model (data )
129
+ loss = F .nll_loss (output , target )
130
+ loss .backward ()
131
+ optimizer .step ()
132
+ if batch_idx % args .log_interval == 0 :
133
+ print ('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}' .format (
134
+ epoch , batch_idx * len (data ), len (train_loader .dataset ),
135
+ 100. * batch_idx / len (train_loader ), loss .item ()))
136
+ if args .dry_run :
137
+ break
138
+
139
+ def test (args , model , device , test_loader ):
140
+ model .eval ()
141
+ test_loss = 0
142
+ correct = 0
143
+ with torch .no_grad ():
144
+ for data , target in test_loader :
145
+ data , target = data .to (device ), target .to (device )
146
+ output = model (data )
147
+ test_loss += F .nll_loss (output , target , reduction = 'sum' ).item ()
148
+ pred = output .argmax (dim = 1 , keepdim = True )
149
+ correct += pred .eq (target .view_as (pred )).sum ().item ()
150
+ if args .dry_run :
151
+ break
152
+
153
+ test_loss /= len (test_loader .dataset )
154
+ print ('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
155
+ test_loss , correct , len (test_loader .dataset ),
156
+ 100. * correct / len (test_loader .dataset )))
157
+
158
+ # ---------- Main ----------
159
+
160
+ def main ():
161
+ parser = argparse .ArgumentParser (description = 'Swin Transformer CIFAR10 Example' )
162
+ parser .add_argument ('--batch-size' , type = int , default = 64 )
163
+ parser .add_argument ('--test-batch-size' , type = int , default = 1000 )
164
+ parser .add_argument ('--epochs' , type = int , default = 10 )
165
+ parser .add_argument ('--lr' , type = float , default = 0.01 )
166
+ parser .add_argument ('--gamma' , type = float , default = 0.7 )
167
+ parser .add_argument ('--dry-run' , action = 'store_true' )
168
+ parser .add_argument ('--seed' , type = int , default = 42 )
169
+ parser .add_argument ('--log-interval' , type = int , default = 10 )
170
+ parser .add_argument ('--save-model' , action = 'store_true' )
171
+ args = parser .parse_args ()
172
+
173
+ use_accel = torch .accelerator .is_available ()
174
+ device = torch .accelerator .current_accelerator () if use_accel else torch .device ("cpu" )
175
+ print (f"Using device: { device } " )
176
+
177
+ torch .manual_seed (args .seed )
178
+
179
+ transform = transforms .Compose ([
180
+ transforms .ToTensor (),
181
+ transforms .Normalize ((0.5 ,), (0.5 ,))
182
+ ])
183
+
184
+ train_loader = torch .utils .data .DataLoader (
185
+ datasets .CIFAR10 ('../data' , train = True , download = True , transform = transform ),
186
+ batch_size = args .batch_size , shuffle = True )
187
+
188
+ test_loader = torch .utils .data .DataLoader (
189
+ datasets .CIFAR10 ('../data' , train = False , transform = transform ),
190
+ batch_size = args .test_batch_size , shuffle = False )
191
+
192
+ model = SwinTinyNet ().to (device )
193
+ optimizer = optim .Adam (model .parameters (), lr = args .lr )
194
+ scheduler = StepLR (optimizer , step_size = 3 , gamma = args .gamma )
195
+
196
+ for epoch in range (1 , args .epochs + 1 ):
197
+ train (args , model , device , train_loader , optimizer , epoch )
198
+ test (args , model , device , test_loader )
199
+ scheduler .step ()
200
+
201
+ if args .save_model :
202
+ torch .save (model .state_dict (), "swin_cifar10.pt" )
203
+ main ()
0 commit comments