Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# poetry
poetry.lock

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# Claude specific
.claude/*

# OS
.DS_Store
Thumbs.db

# Model and data files
*.pth
*.pt
*.h5
*.hdf5
*.pkl
*.pickle
checkpoints/
data/
models/
weights/

# Temporary files
tmp/
temp/
*.tmp
85 changes: 85 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
[tool.poetry]
name = "pytorch-pruning"
version = "0.1.0"
description = "PyTorch model pruning and fine-tuning"
authors = ["Your Name <[email protected]>"]
readme = "README.md"
packages = [{include = "*.py"}]

[tool.poetry.dependencies]
python = "^3.8"
torch = "^2.0.0"
torchvision = "^0.15.0"
opencv-python = "^4.8.0"
numpy = "^1.24.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.1"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*", "*Tests"]
python_functions = ["test_*"]
addopts = [
"-ra",
"--strict-markers",
"--strict-config",
"--cov=.",
"--cov-branch",
"--cov-report=term-missing:skip-covered",
"--cov-report=html:htmlcov",
"--cov-report=xml:coverage.xml",
"--cov-fail-under=0",
"-vv"
]
markers = [
"unit: Unit tests",
"integration: Integration tests",
"slow: Slow tests"
]

[tool.coverage.run]
source = ["."]
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/venv/*",
"*/env/*",
"*/.venv/*",
"setup.py",
"*/conftest.py"
]

[tool.coverage.report]
precision = 2
show_missing = true
skip_covered = false
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod"
]

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Empty file added tests/__init__.py
Empty file.
136 changes: 136 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import pytest
import torch
import tempfile
import shutil
import os
from pathlib import Path
import numpy as np
from unittest.mock import Mock, patch

@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
temp_dir = tempfile.mkdtemp()
yield Path(temp_dir)
shutil.rmtree(temp_dir)

@pytest.fixture
def mock_model():
"""Create a mock PyTorch model for testing."""
model = Mock()
model.parameters = Mock(return_value=iter([
torch.nn.Parameter(torch.randn(3, 3)),
torch.nn.Parameter(torch.randn(10))
]))
model.state_dict = Mock(return_value={
'conv1.weight': torch.randn(64, 3, 3, 3),
'conv1.bias': torch.randn(64),
'fc.weight': torch.randn(10, 512),
'fc.bias': torch.randn(10)
})
model.eval = Mock(return_value=model)
model.train = Mock(return_value=model)
return model

@pytest.fixture
def sample_tensor():
"""Create a sample tensor for testing."""
return torch.randn(1, 3, 224, 224)

@pytest.fixture
def sample_batch():
"""Create a sample batch of data for testing."""
batch_size = 4
images = torch.randn(batch_size, 3, 224, 224)
labels = torch.randint(0, 2, (batch_size,))
return images, labels

@pytest.fixture
def mock_dataset():
"""Create a mock dataset for testing."""
dataset = Mock()
dataset.__len__ = Mock(return_value=100)
dataset.__getitem__ = Mock(side_effect=lambda idx: (
torch.randn(3, 224, 224),
torch.randint(0, 2, (1,)).item()
))
return dataset

@pytest.fixture
def mock_dataloader(mock_dataset):
"""Create a mock dataloader for testing."""
dataloader = Mock()
dataloader.__iter__ = Mock(return_value=iter([
(torch.randn(4, 3, 224, 224), torch.randint(0, 2, (4,)))
for _ in range(5)
]))
dataloader.__len__ = Mock(return_value=5)
dataloader.dataset = mock_dataset
return dataloader

@pytest.fixture
def device():
"""Get the appropriate device for testing."""
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@pytest.fixture
def random_seed():
"""Set random seeds for reproducibility."""
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
yield seed

@pytest.fixture
def mock_vgg_model():
"""Create a mock VGG model for testing."""
with patch('torchvision.models.vgg16') as mock_vgg:
mock_model = Mock()
mock_model.features = torch.nn.Sequential(
torch.nn.Conv2d(3, 64, 3),
torch.nn.ReLU()
)
mock_model.classifier = torch.nn.Sequential(
torch.nn.Linear(25088, 4096),
torch.nn.ReLU(),
torch.nn.Linear(4096, 2)
)
mock_vgg.return_value = mock_model
yield mock_model

@pytest.fixture
def sample_checkpoint(temp_dir):
"""Create a sample checkpoint file."""
checkpoint_path = temp_dir / 'checkpoint.pth'
checkpoint_data = {
'epoch': 10,
'model_state_dict': {
'conv1.weight': torch.randn(64, 3, 3, 3),
'conv1.bias': torch.randn(64)
},
'optimizer_state_dict': {},
'loss': 0.5
}
torch.save(checkpoint_data, checkpoint_path)
return checkpoint_path

@pytest.fixture(autouse=True)
def cleanup_cuda_memory():
"""Clean up CUDA memory after each test."""
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()

@pytest.fixture
def capture_stdout():
"""Capture stdout for testing print statements."""
import io
import sys
captured = io.StringIO()
old_stdout = sys.stdout
sys.stdout = captured
yield captured
sys.stdout = old_stdout
Empty file added tests/integration/__init__.py
Empty file.
Loading