diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b50e183 --- /dev/null +++ b/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Claude specific +.claude/* + +# OS +.DS_Store +Thumbs.db + +# Model and data files +*.pth +*.pt +*.h5 +*.hdf5 +*.pkl +*.pickle +checkpoints/ +data/ +models/ +weights/ + +# Temporary files +tmp/ +temp/ +*.tmp \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cb20d2a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,85 @@ +[tool.poetry] +name = "pytorch-pruning" +version = "0.1.0" +description = "PyTorch model pruning and fine-tuning" +authors = ["Your Name "] +readme = "README.md" +packages = [{include = "*.py"}] + +[tool.poetry.dependencies] +python = "^3.8" +torch = "^2.0.0" +torchvision = "^0.15.0" +opencv-python = "^4.8.0" +numpy = "^1.24.0" + +[tool.poetry.group.dev.dependencies] +pytest = "^7.4.0" +pytest-cov = "^4.1.0" +pytest-mock = "^3.11.1" + +[tool.poetry.scripts] +test = "pytest:main" +tests = "pytest:main" + +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*", "*Tests"] +python_functions = ["test_*"] +addopts = [ + "-ra", + "--strict-markers", + "--strict-config", + "--cov=.", + "--cov-branch", + "--cov-report=term-missing:skip-covered", + "--cov-report=html:htmlcov", + "--cov-report=xml:coverage.xml", + "--cov-fail-under=0", + "-vv" +] +markers = [ + "unit: Unit tests", + "integration: Integration tests", + "slow: Slow tests" +] + +[tool.coverage.run] +source = ["."] +omit = [ + "*/tests/*", + "*/test_*", + "*/__pycache__/*", + "*/venv/*", + "*/env/*", + "*/.venv/*", + "setup.py", + "*/conftest.py" +] + +[tool.coverage.report] +precision = 2 +show_missing = true +skip_covered = false +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod" +] + +[tool.coverage.html] +directory = "htmlcov" + +[tool.coverage.xml] +output = "coverage.xml" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d23bada --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,136 @@ +import pytest +import torch +import tempfile +import shutil +import os +from pathlib import Path +import numpy as np +from unittest.mock import Mock, patch + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.mkdtemp() + yield Path(temp_dir) + shutil.rmtree(temp_dir) + +@pytest.fixture +def mock_model(): + """Create a mock PyTorch model for testing.""" + model = Mock() + model.parameters = Mock(return_value=iter([ + torch.nn.Parameter(torch.randn(3, 3)), + torch.nn.Parameter(torch.randn(10)) + ])) + model.state_dict = Mock(return_value={ + 'conv1.weight': torch.randn(64, 3, 3, 3), + 'conv1.bias': torch.randn(64), + 'fc.weight': torch.randn(10, 512), + 'fc.bias': torch.randn(10) + }) + model.eval = Mock(return_value=model) + model.train = Mock(return_value=model) + return model + +@pytest.fixture +def sample_tensor(): + """Create a sample tensor for testing.""" + return torch.randn(1, 3, 224, 224) + +@pytest.fixture +def sample_batch(): + """Create a sample batch of data for testing.""" + batch_size = 4 + images = torch.randn(batch_size, 3, 224, 224) + labels = torch.randint(0, 2, (batch_size,)) + return images, labels + +@pytest.fixture +def mock_dataset(): + """Create a mock dataset for testing.""" + dataset = Mock() + dataset.__len__ = Mock(return_value=100) + dataset.__getitem__ = Mock(side_effect=lambda idx: ( + torch.randn(3, 224, 224), + torch.randint(0, 2, (1,)).item() + )) + return dataset + +@pytest.fixture +def mock_dataloader(mock_dataset): + """Create a mock dataloader for testing.""" + dataloader = Mock() + dataloader.__iter__ = Mock(return_value=iter([ + (torch.randn(4, 3, 224, 224), torch.randint(0, 2, (4,))) + for _ in range(5) + ])) + dataloader.__len__ = Mock(return_value=5) + dataloader.dataset = mock_dataset + return dataloader + +@pytest.fixture +def device(): + """Get the appropriate device for testing.""" + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +@pytest.fixture +def random_seed(): + """Set random seeds for reproducibility.""" + seed = 42 + torch.manual_seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + yield seed + +@pytest.fixture +def mock_vgg_model(): + """Create a mock VGG model for testing.""" + with patch('torchvision.models.vgg16') as mock_vgg: + mock_model = Mock() + mock_model.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 64, 3), + torch.nn.ReLU() + ) + mock_model.classifier = torch.nn.Sequential( + torch.nn.Linear(25088, 4096), + torch.nn.ReLU(), + torch.nn.Linear(4096, 2) + ) + mock_vgg.return_value = mock_model + yield mock_model + +@pytest.fixture +def sample_checkpoint(temp_dir): + """Create a sample checkpoint file.""" + checkpoint_path = temp_dir / 'checkpoint.pth' + checkpoint_data = { + 'epoch': 10, + 'model_state_dict': { + 'conv1.weight': torch.randn(64, 3, 3, 3), + 'conv1.bias': torch.randn(64) + }, + 'optimizer_state_dict': {}, + 'loss': 0.5 + } + torch.save(checkpoint_data, checkpoint_path) + return checkpoint_path + +@pytest.fixture(autouse=True) +def cleanup_cuda_memory(): + """Clean up CUDA memory after each test.""" + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +@pytest.fixture +def capture_stdout(): + """Capture stdout for testing print statements.""" + import io + import sys + captured = io.StringIO() + old_stdout = sys.stdout + sys.stdout = captured + yield captured + sys.stdout = old_stdout \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_setup_validation.py b/tests/test_setup_validation.py new file mode 100644 index 0000000..206e294 --- /dev/null +++ b/tests/test_setup_validation.py @@ -0,0 +1,98 @@ +import pytest +import torch +import sys +import os +from pathlib import Path + +class TestSetupValidation: + """Validation tests to ensure the testing infrastructure is properly configured.""" + + def test_pytest_is_installed(self): + """Test that pytest is properly installed.""" + assert pytest.__version__ is not None + + def test_torch_is_available(self): + """Test that PyTorch is properly installed.""" + assert torch.__version__ is not None + + def test_project_structure_exists(self): + """Test that the project structure is correctly set up.""" + project_root = Path(__file__).parent.parent + + assert project_root.exists() + assert (project_root / 'tests').exists() + assert (project_root / 'tests' / 'unit').exists() + assert (project_root / 'tests' / 'integration').exists() + assert (project_root / 'tests' / 'conftest.py').exists() + + def test_conftest_fixtures_available(self, temp_dir, mock_model, sample_tensor): + """Test that conftest fixtures are accessible.""" + assert temp_dir.exists() + assert temp_dir.is_dir() + + assert mock_model is not None + assert hasattr(mock_model, 'parameters') + + assert isinstance(sample_tensor, torch.Tensor) + assert sample_tensor.shape == (1, 3, 224, 224) + + def test_coverage_configuration(self): + """Test that coverage is properly configured.""" + try: + import coverage + assert coverage.__version__ is not None + except ImportError: + pytest.skip("Coverage not yet installed") + + def test_mock_utilities_available(self): + """Test that mocking utilities are available.""" + from unittest.mock import Mock, patch + + mock_obj = Mock() + mock_obj.test_method.return_value = 42 + assert mock_obj.test_method() == 42 + + @pytest.mark.unit + def test_unit_marker_works(self): + """Test that the unit test marker is properly configured.""" + assert True + + @pytest.mark.integration + def test_integration_marker_works(self): + """Test that the integration test marker is properly configured.""" + assert True + + @pytest.mark.slow + def test_slow_marker_works(self): + """Test that the slow test marker is properly configured.""" + assert True + + def test_python_path_includes_project_root(self): + """Test that the project root is in Python path for imports.""" + project_root = str(Path(__file__).parent.parent) + assert any(project_root in path for path in sys.path) + + def test_can_import_project_modules(self): + """Test that project modules can be imported.""" + try: + import dataset + import finetune + import prune + assert True + except ImportError as e: + pytest.fail(f"Failed to import project modules: {e}") + + def test_device_fixture_works(self, device): + """Test that the device fixture returns a valid device.""" + assert isinstance(device, torch.device) + assert device.type in ['cpu', 'cuda'] + + def test_random_seed_fixture_provides_reproducibility(self, random_seed): + """Test that random seed fixture ensures reproducibility.""" + assert random_seed == 42 + + tensor1 = torch.randn(5) + torch.manual_seed(random_seed) + tensor2 = torch.randn(5) + + assert torch.allclose(tensor1, tensor2) \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29