Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
virtualenv/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDE files
.idea/
.vscode/
*.swp
*.swo
*~

# OS files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Claude settings
.claude/*

# Poetry
# Note: poetry.lock should be committed for applications
# poetry.lock

# UV
# Note: uv.lock should be committed for applications
# uv.lock
1,214 changes: 1,214 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

96 changes: 96 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
[tool.poetry]
name = "vit-transformer-explainability"
version = "0.1.0"
description = "Vision Transformer (ViT) visualization and explanation tools"
authors = ["Your Name <[email protected]>"]
readme = "Readme.md"
license = "MIT"
packages = [{include = "*.py"}]

[tool.poetry.dependencies]
python = "^3.8"
torch = "^2.0.0"
torchvision = "^0.15.0"
Pillow = "^10.0.0"
numpy = "^1.24.0"
opencv-python = "^4.8.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.0"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"-ra",
"--strict-markers",
"--import-mode=importlib",
"--cov=.",
"--cov-report=term-missing:skip-covered",
"--cov-report=html",
"--cov-report=xml",
"--cov-fail-under=80",
]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
testpaths = ["tests"]
markers = [
"unit: marks tests as unit tests (fast, isolated)",
"integration: marks tests as integration tests (may be slower)",
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
]

[tool.coverage.run]
source = ["."]
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/venv/*",
"*/env/*",
"*/.venv/*",
"*/.env/*",
"*/virtualenv/*",
"*/node_modules/*",
"*/migrations/*",
"*/examples/*",
"*/build/*",
"*/dist/*",
"*/.pytest_cache/*",
"*/.coverage",
"*/htmlcov/*",
"*/site-packages/*",
"config*.py",
]

[tool.coverage.report]
precision = 2
show_missing = true
skip_covered = true
fail_under = 80
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if __name__ == .__main__.:",
"raise AssertionError",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"pass",
"@abstractmethod",
]

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Empty file added tests/__init__.py
Empty file.
108 changes: 108 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest
import tempfile
import os
import shutil
from pathlib import Path
import numpy as np
from PIL import Image
import torch


@pytest.fixture
def temp_dir():
"""Create a temporary directory that is cleaned up after the test."""
temp_path = tempfile.mkdtemp()
yield Path(temp_path)
shutil.rmtree(temp_path)


@pytest.fixture
def sample_image(temp_dir):
"""Create a sample RGB image for testing."""
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
img = Image.fromarray(img_array, 'RGB')
img_path = temp_dir / "test_image.jpg"
img.save(img_path)
return img_path


@pytest.fixture
def sample_tensor():
"""Create a sample tensor for testing."""
return torch.randn(1, 3, 224, 224)


@pytest.fixture
def mock_model():
"""Create a mock model for testing."""
class MockModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.blocks = torch.nn.ModuleList([
torch.nn.Identity() for _ in range(12)
])

def forward(self, x):
return x

return MockModel()


@pytest.fixture
def attention_rollout_config():
"""Configuration for attention rollout testing."""
return {
'head_fusion': 'mean',
'discard_ratio': 0.9,
'class_token': True
}


@pytest.fixture
def grad_rollout_config():
"""Configuration for gradient rollout testing."""
return {
'use_cuda': False,
'category_index': None,
'head_fusion': 'max',
'discard_ratio': 0.9
}


@pytest.fixture
def mock_attention_weights():
"""Create mock attention weights for testing."""
batch_size = 1
num_heads = 3
seq_len = 197 # 196 patches + 1 class token
return torch.rand(batch_size, num_heads, seq_len, seq_len)


@pytest.fixture
def mock_gradients():
"""Create mock gradients for testing."""
return torch.randn(1, 197, 192) # batch_size, seq_len, hidden_dim


@pytest.fixture
def output_dir(temp_dir):
"""Create an output directory for test results."""
output_path = temp_dir / "output"
output_path.mkdir(exist_ok=True)
return output_path


@pytest.fixture(autouse=True)
def reset_torch_hub_dir(monkeypatch):
"""Reset torch hub directory to avoid downloading models during tests."""
monkeypatch.setenv('TORCH_HOME', '/tmp/torch_hub_test')


@pytest.fixture
def mock_transforms():
"""Mock torchvision transforms."""
def transform(x):
if isinstance(x, Image.Image):
return torch.randn(3, 224, 224)
return x
return transform
Empty file added tests/integration/__init__.py
Empty file.
Loading