37 commits
bd0855c
Initial work on SparK
johnsutor Dec 5, 2024
42bd849
feat(spark): More work on spark
johnsutor Dec 17, 2024
7e73b83
Merge branch 'spark' of github.com:johnsutor/lightly
gabrielfruet Feb 5, 2026
724819c
refactor: should not nest modules.
gabrielfruet Feb 5, 2026
b0023f0
refactor: removed empty file
gabrielfruet Feb 5, 2026
8120990
refactor: adhered to correct directory structure.
gabrielfruet Feb 5, 2026
a7138d0
feat: put everything into the sparse spark module.
gabrielfruet Feb 5, 2026
9b869e9
refactor: removed redundant super calls with class.
gabrielfruet Feb 5, 2026
36cb94c
refactor: removed spark code.
gabrielfruet Feb 5, 2026
0afc280
refactor: removing redundant super calls
gabrielfruet Feb 5, 2026
280a479
refactor: remove empty file
gabrielfruet Feb 6, 2026
7e0f2bc
refactor: porting original code. starting from scratch
gabrielfruet Feb 6, 2026
9a9dff6
refactor: fixing type hint problems.
gabrielfruet Feb 6, 2026
407a4c0
refactor: removing unnecessary redundant super
gabrielfruet Feb 6, 2026
47eecf3
fix: indentation
gabrielfruet Feb 6, 2026
7a956b3
feat: working module
gabrielfruet Feb 6, 2026
9d0f5c4
refactor: using the library's already-implemented masking
gabrielfruet Feb 9, 2026
4aa69f7
feat: using patchify
gabrielfruet Feb 9, 2026
8b95a51
refactor: putting densification into a single module
gabrielfruet Feb 9, 2026
014f724
typo: raito -> ratio
gabrielfruet Feb 9, 2026
de01b0a
feat: encapsulated logic into a single densifier module
gabrielfruet Feb 9, 2026
445bda7
refactor: cleaning code.
gabrielfruet Feb 9, 2026
741d531
refactor: letting the sparse encoder be responsible for sizes, etc.
gabrielfruet Feb 9, 2026
fb5c90d
feat: resnet18
gabrielfruet Feb 9, 2026
deeb4ef
refactor: removing unused code
gabrielfruet Feb 9, 2026
64df2e3
fix: bool tensor is inconvenient
gabrielfruet Feb 9, 2026
45b6eb8
refactor: documenting
gabrielfruet Feb 9, 2026
fb7903a
refactor: masking as a module
gabrielfruet Feb 9, 2026
138380f
refactor: removing unused variables
gabrielfruet Feb 9, 2026
6b3dbdf
refactor: removing unnecessary module dependency
gabrielfruet Feb 9, 2026
ed8691a
refactor: loss as module
gabrielfruet Feb 9, 2026
3dcd5c3
refactor: spark visualization decoding logic as module
gabrielfruet Feb 9, 2026
1c316da
refactor: remove unused
gabrielfruet Feb 9, 2026
3ba1839
refactor
gabrielfruet Feb 9, 2026
70d32dd
refactor: removed big module and refactored timm funcs
gabrielfruet Feb 9, 2026
2f73d15
feat: example script
gabrielfruet Feb 9, 2026
b48d4eb
refactor: removed unused code and added copyrights
gabrielfruet Feb 9, 2026
176 changes: 176 additions & 0 deletions examples/pytorch_lightning/spark.py
@@ -0,0 +1,176 @@
# This example requires the following dependencies to be installed:
# pip install lightly timm

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import timm
import torch
import torchvision
from pytorch_lightning.callbacks import RichProgressBar
from torch import nn
from torchvision.transforms import v2

# SparK building blocks: sparse encoder, masker, densifier, decoder, and loss.
from lightly.models.modules.sparse_spark import (
    LightDecoder,
    SparKDensifier,
    SparKMasker,
    SparKMaskingOuptut,
    SparKOutputDecoder,
    SparKPatchReconLoss,
    SparseEncoder,
)
from lightly.models.utils import patchify


def get_downsample_ratio_from_timm_model(model: nn.Module) -> int:
    return model.feature_info[-1]["reduction"]


def get_enc_feat_map_chs_from_timm_model(model: nn.Module) -> list[int]:
    return [fi["num_chs"] for fi in model.feature_info]
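# For reference: timm's resnet18 with features_only=True typically reports
# reductions [2, 4, 8, 16, 32] and channels [64, 64, 128, 256, 512], so the
# helpers above return a downsample ratio of 32 and five channel counts.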


class SparseSparK(pl.LightningModule):
    def __init__(
        self,
        input_size: int = 416,
        mask_ratio: float = 0.6,
        densify_norm: str = "bn",
        sbn: bool = False,
    ):
        super().__init__()
        backbone = timm.create_model(
            "resnet18", drop_path_rate=0.05, features_only=True
        )
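        # Following the SparK approach, the sparse encoder wraps the timm
        # backbone so that convolutions and normalization act only on
        # unmasked (visible) positions of the input.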
        self.sparse_encoder = SparseEncoder(
            backbone,
            downsample_ratio=get_downsample_ratio_from_timm_model(backbone),
            feature_map_channels=get_enc_feat_map_chs_from_timm_model(backbone),
            input_size=input_size,
            sbn=sbn,
            verbose=True,
        )
        self.dense_decoder = LightDecoder(
            self.sparse_encoder.downsample_ratio,
            width=self.sparse_encoder.enc_feat_map_chs[-1],
        )
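        # The masker draws a random mask at the final feature-map resolution
        # (masking mask_ratio of the patches) and scales it to every level of
        # the feature hierarchy.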
        self.masker = SparKMasker(
            feature_map_size=(self.sparse_encoder.fmap_h, self.sparse_encoder.fmap_w),
            downsample_ratio=self.sparse_encoder.downsample_ratio,
            mask_ratio=mask_ratio,
        )
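        # The densifier turns the sparse encoder outputs back into dense
        # feature maps, presumably by filling masked positions with learnable
        # mask embeddings and projecting each level to the decoder width.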
        self.densifier = SparKDensifier(
            encoder_in_channels=self.sparse_encoder.enc_feat_map_chs,
            decoder_in_channel=self.dense_decoder.width,
            densify_norm_str=densify_norm.lower(),
            sbn=sbn,
        )
        self.downsample_ratio = self.sparse_encoder.downsample_ratio
        # loss module for patch reconstruction
        self.recon_loss_fn = SparKPatchReconLoss()
        # output decoder for visualization (pass minimal spatial props)
        self.output_decoder = SparKOutputDecoder(
            self.sparse_encoder.fmap_h,
            self.sparse_encoder.fmap_w,
            self.sparse_encoder.downsample_ratio,
        )

    def forward(
        self,
        inp_bchw: torch.Tensor,
        vis: bool = False,
    ):
        # step1. Mask
        mask_out: SparKMaskingOuptut = self.masker(inp_bchw)
        masked_bchw, per_level_mask = mask_out
        active_b1fHfW = per_level_mask[0]
        active_b1hw = per_level_mask[-1]
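        # active_b1fHfW: visibility mask at the final feature-map resolution
        # (B, 1, fH, fW); active_b1hw: the same mask upsampled to the input
        # resolution (B, 1, H, W).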
        # step2. Encode: get hierarchical encoded sparse features (a list with one feature map per scale)
        fea_bcffs: list[torch.Tensor] = self.sparse_encoder(masked_bchw)
        # step3. Densify: get hierarchical dense features for decoding
        to_dec = self.densifier(fea_bcffs)
        # step4. Decode and reconstruct
        rec_bchw = self.dense_decoder(to_dec)
        inp, rec = (
            patchify(inp_bchw, self.downsample_ratio),
            patchify(rec_bchw, self.downsample_ratio),
        )  # inp and rec: (B, L = f*f, N = C * downsample_ratio**2)
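        # Shape check (input_size=416, downsample_ratio=32): the final feature
        # map is 13 x 13, so L = 169 patches with N = 3 * 32**2 = 3072 values each.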

        # MAE-style reconstruction loss, presumably normalized per patch and
        # computed only where patches were masked; mean/var are the per-patch
        # statistics, reused by the visualization decoder below.
        recon_loss, mean, var = self.recon_loss_fn(inp, rec, active_b1fHfW)

        if vis:
            return self.output_decoder(rec, mean, var, inp_bchw, active_b1hw)
        else:
            return recon_loss

    def training_step(self, batch, batch_index) -> torch.Tensor:
        img, _ = batch  # the label is unused for self-supervised pretraining
        recon_loss = self.forward(img)
        # Log the training loss to logger and progress bar (per-step and per-epoch)
        self.log(
            "train_loss",
            recon_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return recon_loss

    def configure_optimizers(self):
        # Note: the reference implementation uses a different optimizer and
        # schedule; plain SGD keeps this example lightweight.
        return torch.optim.SGD(
            self.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4
        )


model = SparseSparK(input_size=416)


# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0


dataset = torchvision.datasets.Caltech101(
    "datasets/caltech101",
    download=True,
    transform=v2.Compose(
        [
            v2.Resize((416, 416)),
            v2.RGB(),
            # Convert to a float tensor in [0, 1] (the v2-native replacement
            # for the deprecated v2.ToTensor).
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    ),
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)


accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(
    max_epochs=30,
    devices=1,
    accelerator=accelerator,
    callbacks=[
        RichProgressBar(),
    ],
)
trainer.fit(
    model=model,
    train_dataloaders=dataloader,
)
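

# A minimal sketch of the visualization path, assuming SparKOutputDecoder
# returns image-like tensors that torchvision can save (adjust to the actual
# return type):
#
# model.eval()
# with torch.no_grad():
#     images, _ = next(iter(dataloader))
#     decoded = model(images, vis=True)
#     torchvision.utils.save_image(decoded, "spark_reconstruction.png")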