Skip to content

Commit 892eca0

Browse files
sumantro93dvrogozh
andauthored
Add Swin Transformer Example (#1346)
* add swin transformer example * add swin transformer example * with accelerator API * fixes requirements,code and readme * Update run_python_examples.sh Co-authored-by: Dmitry Rogozhkin <[email protected]> * Update swin_transformer/README.md Co-authored-by: Dmitry Rogozhkin <[email protected]> --------- Co-authored-by: Dmitry Rogozhkin <[email protected]>
1 parent a630ec6 commit 892eca0

File tree

4 files changed

+273
-2
lines changed

4 files changed

+273
-2
lines changed

run_python_examples.sh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ function gat() {
171171
uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
172172
}
173173

174+
function swin() {
175+
uv run swin_transformer.py --epochs 1 --dry-run || error "swin transformer failed"
176+
}
177+
174178
eval "base_$(declare -f stop)"
175179

176180
function stop() {
@@ -195,8 +199,8 @@ function stop() {
195199
time_sequence_prediction/traindata.pt \
196200
word_language_model/model.pt \
197201
gcn/cora/ \
198-
gat/cora/ || error "couldn't clean up some files"
199-
202+
gat/cora/ \
203+
swin_trasformer/swin_cifar10.pt || error "couldn't clean up some files"
200204
git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
201205

202206
base_stop "$1"
@@ -224,6 +228,7 @@ function run_all() {
224228
run fx
225229
run gcn
226230
run gat
231+
run swin_transformer
227232
}
228233

229234
# by default, run all examples

swin_transformer/README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Swin Transformer on CIFAR-10
2+
3+
This project demonstrates a minimal implementation of a **Swin Transformer** for image classification on the **CIFAR-10** dataset using PyTorch.
4+
5+
It includes:
6+
- Patch embedding and window-based self-attention
7+
- Shifted windows for hierarchical representation
8+
- Training and testing logic using standard PyTorch utilities
9+
10+
---
11+
12+
## Files
13+
14+
- `swin_transformer.py` — Full implementation of the Swin Transformer model, training loop, and evaluation on CIFAR-10.
15+
- `README.md` — This file.
16+
17+
---
18+
19+
## Requirements
20+
21+
- Python 3.8+
22+
- PyTorch 2.6 or later
23+
- `torchvision` (for CIFAR-10 dataset)
24+
25+
Install dependencies:
26+
27+
```bash
28+
pip install -r requirements.txt
29+
```
30+
31+
---
32+
33+
## Usage
34+
35+
### Train & Save the model
36+
37+
```bash
38+
python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001 --save-model
39+
```
40+
41+
### Test the model
42+
43+
Testing is done automatically after each epoch. To only test, run with:
44+
45+
```bash
46+
python swin_transformer.py --epochs 1
47+
``
48+
49+
The model will be saved as `swin_cifar10.pt`.
50+
51+
---
52+
53+
## Features
54+
55+
- Uses shifted window attention for local self-attention.
56+
- Patch-based embedding with a lightweight network.
57+
- Trains on CIFAR-10 with `Adam` optimizer and learning rate scheduling.
58+
- Prints loss and accuracy per epoch.
59+
60+
---
61+

swin_transformer/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch>=2.6
2+
torchvision

swin_transformer/swin_transformer.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import argparse
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
import torch.optim as optim
6+
from torch.optim.lr_scheduler import StepLR
7+
from torchvision import datasets, transforms
8+
9+
# ---------- Core Swin Components ----------
10+
11+
class PatchEmbed(nn.Module):
12+
def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=48):
13+
super().__init__()
14+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
15+
self.norm = nn.LayerNorm(embed_dim)
16+
17+
def forward(self, x):
18+
x = self.proj(x)
19+
x = x.flatten(2).transpose(1, 2)
20+
x = self.norm(x)
21+
return x
22+
23+
def window_partition(x, window_size):
24+
B, H, W, C = x.shape
25+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
26+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
27+
return windows
28+
29+
def window_reverse(windows, window_size, H, W):
30+
B = int(windows.shape[0] / (H * W / window_size / window_size))
31+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
32+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
33+
return x
34+
35+
class WindowAttention(nn.Module):
36+
def __init__(self, dim, window_size, num_heads):
37+
super().__init__()
38+
self.num_heads = num_heads
39+
head_dim = dim // num_heads
40+
self.scale = head_dim ** -0.5
41+
42+
self.qkv = nn.Linear(dim, dim * 3)
43+
self.proj = nn.Linear(dim, dim)
44+
45+
def forward(self, x):
46+
B_, N, C = x.shape
47+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
48+
q, k, v = qkv.permute(2, 0, 3, 1, 4)
49+
50+
attn = (q @ k.transpose(-2, -1)) * self.scale
51+
attn = attn.softmax(dim=-1)
52+
53+
out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
54+
return self.proj(out)
55+
56+
class SwinTransformerBlock(nn.Module):
57+
def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0):
58+
super().__init__()
59+
self.dim = dim
60+
self.input_resolution = input_resolution
61+
self.window_size = window_size
62+
self.shift_size = shift_size
63+
64+
self.norm1 = nn.LayerNorm(dim)
65+
self.attn = WindowAttention(dim, window_size, num_heads)
66+
self.norm2 = nn.LayerNorm(dim)
67+
68+
self.mlp = nn.Sequential(
69+
nn.Linear(dim, dim * 4),
70+
nn.GELU(),
71+
nn.Linear(dim * 4, dim)
72+
)
73+
74+
def forward(self, x):
75+
H, W = self.input_resolution
76+
B, L, C = x.shape
77+
x = x.view(B, H, W, C)
78+
79+
if self.shift_size > 0:
80+
shifted_x = torch.roll(x, (-self.shift_size, -self.shift_size), (1, 2))
81+
else:
82+
shifted_x = x
83+
84+
windows = window_partition(shifted_x, self.window_size)
85+
windows = windows.view(-1, self.window_size * self.window_size, C)
86+
87+
attn_windows = self.attn(self.norm1(windows))
88+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
89+
90+
shifted_x = window_reverse(attn_windows, self.window_size, H, W)
91+
92+
if self.shift_size > 0:
93+
x = torch.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2))
94+
else:
95+
x = shifted_x
96+
97+
x = x.view(B, H * W, C)
98+
x = x + self.mlp(self.norm2(x))
99+
return x
100+
101+
# ---------- Final Network ----------
102+
103+
class SwinTinyNet(nn.Module):
104+
def __init__(self, num_classes=10):
105+
super(SwinTinyNet, self).__init__()
106+
self.patch_embed = PatchEmbed(img_size=32, patch_size=4, in_chans=3, embed_dim=48)
107+
self.block1 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=0)
108+
self.block2 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=2)
109+
self.norm = nn.LayerNorm(48)
110+
self.fc = nn.Linear(48, num_classes)
111+
112+
def forward(self, x):
113+
x = self.patch_embed(x)
114+
x = self.block1(x)
115+
x = self.block2(x)
116+
x = self.norm(x)
117+
x = x.mean(dim=1)
118+
x = self.fc(x)
119+
return F.log_softmax(x, dim=1)
120+
121+
# ---------- Training and Testing ----------
122+
123+
def train(args, model, device, train_loader, optimizer, epoch):
124+
model.train()
125+
for batch_idx, (data, target) in enumerate(train_loader):
126+
data, target = data.to(device), target.to(device)
127+
optimizer.zero_grad()
128+
output = model(data)
129+
loss = F.nll_loss(output, target)
130+
loss.backward()
131+
optimizer.step()
132+
if batch_idx % args.log_interval == 0:
133+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
134+
epoch, batch_idx * len(data), len(train_loader.dataset),
135+
100. * batch_idx / len(train_loader), loss.item()))
136+
if args.dry_run:
137+
break
138+
139+
def test(args, model, device, test_loader):
140+
model.eval()
141+
test_loss = 0
142+
correct = 0
143+
with torch.no_grad():
144+
for data, target in test_loader:
145+
data, target = data.to(device), target.to(device)
146+
output = model(data)
147+
test_loss += F.nll_loss(output, target, reduction='sum').item()
148+
pred = output.argmax(dim=1, keepdim=True)
149+
correct += pred.eq(target.view_as(pred)).sum().item()
150+
if args.dry_run:
151+
break
152+
153+
test_loss /= len(test_loader.dataset)
154+
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
155+
test_loss, correct, len(test_loader.dataset),
156+
100. * correct / len(test_loader.dataset)))
157+
158+
# ---------- Main ----------
159+
160+
def main():
161+
parser = argparse.ArgumentParser(description='Swin Transformer CIFAR10 Example')
162+
parser.add_argument('--batch-size', type=int, default=64)
163+
parser.add_argument('--test-batch-size', type=int, default=1000)
164+
parser.add_argument('--epochs', type=int, default=10)
165+
parser.add_argument('--lr', type=float, default=0.01)
166+
parser.add_argument('--gamma', type=float, default=0.7)
167+
parser.add_argument('--dry-run', action='store_true')
168+
parser.add_argument('--seed', type=int, default=42)
169+
parser.add_argument('--log-interval', type=int, default=10)
170+
parser.add_argument('--save-model', action='store_true')
171+
args = parser.parse_args()
172+
173+
use_accel = torch.accelerator.is_available()
174+
device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
175+
print(f"Using device: {device}")
176+
177+
torch.manual_seed(args.seed)
178+
179+
transform = transforms.Compose([
180+
transforms.ToTensor(),
181+
transforms.Normalize((0.5,), (0.5,))
182+
])
183+
184+
train_loader = torch.utils.data.DataLoader(
185+
datasets.CIFAR10('../data', train=True, download=True, transform=transform),
186+
batch_size=args.batch_size, shuffle=True)
187+
188+
test_loader = torch.utils.data.DataLoader(
189+
datasets.CIFAR10('../data', train=False, transform=transform),
190+
batch_size=args.test_batch_size, shuffle=False)
191+
192+
model = SwinTinyNet().to(device)
193+
optimizer = optim.Adam(model.parameters(), lr=args.lr)
194+
scheduler = StepLR(optimizer, step_size=3, gamma=args.gamma)
195+
196+
for epoch in range(1, args.epochs + 1):
197+
train(args, model, device, train_loader, optimizer, epoch)
198+
test(args, model, device, test_loader)
199+
scheduler.step()
200+
201+
if args.save_model:
202+
torch.save(model.state_dict(), "swin_cifar10.pt")
203+
main()

0 commit comments

Comments
 (0)