Skip to content

Commit 3d6f1b2

Browse files
committed
packaging1
1 parent 7960ad5 commit 3d6f1b2

File tree

5 files changed

+86
-27
lines changed

5 files changed

+86
-27
lines changed

AROS/main.py renamed to main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11

2-
!pip install -r requirements.txt
2+
import aros
33
import argparse
44
import torch
55
import torch.nn as nn
6-
from evaluate import *
7-
from utils import *
6+
from aros.evaluate import *
7+
from aros.utils import *
88
from tqdm.notebook import tqdm
9-
from data_loader import *
10-
from stability_loss_function import *
9+
from aros.data_loader import *
10+
from aros.stability_loss_function import *
1111

1212
def main():
1313
parser = argparse.ArgumentParser(description="Hyperparameters for the script")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[build-system]
2+
requires = ["setuptools>=42", "wheel"]
3+
build-backend = "setuptools.build_meta"

requirements.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
geotorch
2+
torch
23
torchdiffeq
3-
git+https://github.com/RobustBench/robustbench.git
4-
timm==1.0.9
4+
timm==1.0.9
5+
robustbench
6+
numpy
7+
scikit-learn
8+
scipy
9+
tqdm

setup.cfg

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,18 @@ author_email = [email protected]
66
description = AROS: Adversarially Robust Out-of-Distribution Detection through Stability
77
long_description = file: README.md
88
long_description_content_type = text/markdown
9-
license_files = LICENSE.md
10-
license_file_type = text/markdown
119
url = https://github.com/AdaptiveMotorControlLab/AROS
12-
project_urls =
13-
Bug Tracker = https://github.com/AdaptiveMotorControlLab/AROS/issues
14-
classifiers =
15-
Development Status :: 4 - Beta
16-
Environment :: GPU :: NVIDIA CUDA
17-
Intended Audience :: Science/Research
18-
Operating System :: OS Independent
19-
Programming Language :: Python :: 3
20-
Topic :: Scientific/Engineering :: Artificial Intelligence
21-
License :: OSI Approved :: Apache Software License
2210

2311
[options]
2412
packages = find:
2513
include_package_data = True
2614
python_requires = >=3.10
27-
install_requires =
28-
geotorch
29-
torchdiffeq
30-
git+https://github.com/RobustBench/robustbench.git
15+
install_requires = file: requirements.txt
3116

3217
[options.extras_require]
3318
dev =
3419
pylint
3520
toml
3621
yapf
3722
black
38-
pytest
39-
40-
[bdist_wheel]
41-
universal=0
23+
pytest

tests/test_dataloaders.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
import torch
3+
from torch.utils.data import DataLoader, Subset
4+
from torchvision.datasets import CIFAR10, CIFAR100
5+
from torchvision.transforms import ToTensor
6+
from aros import (
7+
LabelChangedDataset,
8+
get_subsampled_subset,
9+
get_loaders,
10+
)
11+
12+
# Set up transformations and datasets for tests
13+
transform_tensor = ToTensor()
14+
15+
@pytest.fixture
16+
def cifar10_datasets():
17+
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_tensor)
18+
testset = CIFAR10(root='./data', train=False, download=True, transform=transform_tensor)
19+
return trainset, testset
20+
21+
@pytest.fixture
22+
def cifar100_datasets():
23+
trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_tensor)
24+
testset = CIFAR100(root='./data', train=False, download=True, transform=transform_tensor)
25+
return trainset, testset
26+
27+
def test_label_changed_dataset(cifar10_datasets):
28+
_, testset = cifar10_datasets
29+
new_label = 99
30+
relabeled_dataset = LabelChangedDataset(testset, new_label)
31+
32+
assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length"
33+
34+
for img, label in relabeled_dataset:
35+
assert label == new_label, "All labels should be changed to the new label"
36+
37+
def test_get_subsampled_subset(cifar10_datasets):
38+
trainset, _ = cifar10_datasets
39+
subset_ratio = 0.1
40+
subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio)
41+
42+
expected_size = int(len(trainset) * subset_ratio)
43+
assert len(subset) == expected_size, f"Subset size should be {expected_size}"
44+
45+
def test_get_loaders_cifar10(cifar10_datasets):
46+
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar10')
47+
48+
assert isinstance(train_loader, DataLoader)
49+
assert isinstance(test_loader, DataLoader)
50+
assert isinstance(test_loader_vs_other, DataLoader)
51+
52+
for images, labels in test_loader:
53+
assert images.shape[0] == 16, "Test loader batch size should be 16"
54+
break
55+
56+
def test_get_loaders_cifar100(cifar100_datasets):
57+
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar100')
58+
59+
assert isinstance(train_loader, DataLoader)
60+
assert isinstance(test_loader, DataLoader)
61+
assert isinstance(test_loader_vs_other, DataLoader)
62+
63+
for images, labels in test_loader:
64+
assert images.shape[0] == 16, "Test loader batch size should be 16"
65+
break
66+
67+
def test_get_loaders_invalid_dataset():
68+
with pytest.raises(ValueError, match="Dataset 'invalid_dataset' is not supported."):
69+
get_loaders('invalid_dataset')

0 commit comments

Comments
 (0)