diff --git a/run_python_examples.sh b/run_python_examples.sh index 9544d2ed8d..0a1452e1c0 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -171,10 +171,6 @@ function gat() { uv run main.py --epochs 1 --dry-run || error "graph attention network failed" } -function swin() { - uv run swin_transformer.py --epochs 1 --dry-run || error "swin transformer failed" -} - eval "base_$(declare -f stop)" function stop() { @@ -199,8 +195,8 @@ function stop() { time_sequence_prediction/traindata.pt \ word_language_model/model.pt \ gcn/cora/ \ - gat/cora/ \ - swin_trasformer/swin_cifar10.pt || error "couldn't clean up some files" + gat/cora/ || error "couldn't clean up some files" + git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image" base_stop "$1" @@ -228,7 +224,6 @@ function run_all() { run fx run gcn run gat - run swin_transformer } # by default, run all examples diff --git a/swin_transformer/README.md b/swin_transformer/README.md deleted file mode 100644 index 37f789be37..0000000000 --- a/swin_transformer/README.md +++ /dev/null @@ -1,61 +0,0 @@ -# Swin Transformer on CIFAR-10 - -This project demonstrates a minimal implementation of a **Swin Transformer** for image classification on the **CIFAR-10** dataset using PyTorch. - -It includes: -- Patch embedding and window-based self-attention -- Shifted windows for hierarchical representation -- Training and testing logic using standard PyTorch utilities - ---- - -## Files - -- `swin_transformer.py` — Full implementation of the Swin Transformer model, training loop, and evaluation on CIFAR-10. -- `README.md` — This file. - ---- - -## Requirements - -- Python 3.8+ -- PyTorch 2.6 or later -- `torchvision` (for CIFAR-10 dataset) - -Install dependencies: - -```bash -pip install -r requirements.txt -``` - ---- - -## Usage - -### Train & Save the model - -```bash -python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001 --save-model -``` - -### Test the model - -Testing is done automatically after each epoch. To only test, run with: - -```bash -python swin_transformer.py --epochs 1 -`` - -The model will be saved as `swin_cifar10.pt`. - ---- - -## Features - -- Uses shifted window attention for local self-attention. -- Patch-based embedding with a lightweight network. -- Trains on CIFAR-10 with `Adam` optimizer and learning rate scheduling. -- Prints loss and accuracy per epoch. - ---- - diff --git a/swin_transformer/requirements.txt b/swin_transformer/requirements.txt deleted file mode 100644 index 9a083ba390..0000000000 --- a/swin_transformer/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torch>=2.6 -torchvision diff --git a/swin_transformer/swin_transformer.py b/swin_transformer/swin_transformer.py deleted file mode 100644 index a29fbd5fff..0000000000 --- a/swin_transformer/swin_transformer.py +++ /dev/null @@ -1,203 +0,0 @@ -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.optim.lr_scheduler import StepLR -from torchvision import datasets, transforms - -# ---------- Core Swin Components ---------- - -class PatchEmbed(nn.Module): - def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=48): - super().__init__() - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - return x - -def window_partition(x, window_size): - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - -def window_reverse(windows, window_size, H, W): - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - def __init__(self, dim, window_size, num_heads): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads) - q, k, v = qkv.permute(2, 0, 3, 1, 4) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - - out = (attn @ v).transpose(1, 2).reshape(B_, N, C) - return self.proj(out) - -class SwinTransformerBlock(nn.Module): - def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.window_size = window_size - self.shift_size = shift_size - - self.norm1 = nn.LayerNorm(dim) - self.attn = WindowAttention(dim, window_size, num_heads) - self.norm2 = nn.LayerNorm(dim) - - self.mlp = nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Linear(dim * 4, dim) - ) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - x = x.view(B, H, W, C) - - if self.shift_size > 0: - shifted_x = torch.roll(x, (-self.shift_size, -self.shift_size), (1, 2)) - else: - shifted_x = x - - windows = window_partition(shifted_x, self.window_size) - windows = windows.view(-1, self.window_size * self.window_size, C) - - attn_windows = self.attn(self.norm1(windows)) - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - if self.shift_size > 0: - x = torch.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2)) - else: - x = shifted_x - - x = x.view(B, H * W, C) - x = x + self.mlp(self.norm2(x)) - return x - -# ---------- Final Network ---------- - -class SwinTinyNet(nn.Module): - def __init__(self, num_classes=10): - super(SwinTinyNet, self).__init__() - self.patch_embed = PatchEmbed(img_size=32, patch_size=4, in_chans=3, embed_dim=48) - self.block1 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=0) - self.block2 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=2) - self.norm = nn.LayerNorm(48) - self.fc = nn.Linear(48, num_classes) - - def forward(self, x): - x = self.patch_embed(x) - x = self.block1(x) - x = self.block2(x) - x = self.norm(x) - x = x.mean(dim=1) - x = self.fc(x) - return F.log_softmax(x, dim=1) - -# ---------- Training and Testing ---------- - -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) - if args.dry_run: - break - -def test(args, model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - if args.dry_run: - break - - test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) - -# ---------- Main ---------- - -def main(): - parser = argparse.ArgumentParser(description='Swin Transformer CIFAR10 Example') - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--test-batch-size', type=int, default=1000) - parser.add_argument('--epochs', type=int, default=10) - parser.add_argument('--lr', type=float, default=0.01) - parser.add_argument('--gamma', type=float, default=0.7) - parser.add_argument('--dry-run', action='store_true') - parser.add_argument('--seed', type=int, default=42) - parser.add_argument('--log-interval', type=int, default=10) - parser.add_argument('--save-model', action='store_true') - args = parser.parse_args() - - use_accel = torch.accelerator.is_available() - device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu") - print(f"Using device: {device}") - - torch.manual_seed(args.seed) - - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - - train_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('../data', train=True, download=True, transform=transform), - batch_size=args.batch_size, shuffle=True) - - test_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('../data', train=False, transform=transform), - batch_size=args.test_batch_size, shuffle=False) - - model = SwinTinyNet().to(device) - optimizer = optim.Adam(model.parameters(), lr=args.lr) - scheduler = StepLR(optimizer, step_size=3, gamma=args.gamma) - - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - test(args, model, device, test_loader) - scheduler.step() - - if args.save_model: - torch.save(model.state_dict(), "swin_cifar10.pt") -main() \ No newline at end of file