Skip to content

Commit 71eb422

Browse files
authored
Merge pull request #1 from abeludc93/update-dependencies-and-tests
Update dependencies and tests
2 parents 22ac568 + d029c33 commit 71eb422

File tree

10 files changed

+24
-35
lines changed

10 files changed

+24
-35
lines changed

example/SplitEvaluateExample.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
1+
#%% Generate a large dataset
12
from sklearn.datasets import make_classification
2-
3-
# Generate a large dataset
43
X, y = make_classification(n_samples=10000, n_features=10)
54

6-
5+
#%% Imports VertiBench
76
from vertibench.Evaluator import ImportanceEvaluator, CorrelationEvaluator
87
from vertibench.Splitter import ImportanceSplitter, CorrelationSplitter
98
from sklearn.linear_model import LogisticRegression
109

11-
# Split by importance
10+
#%% Split by importance
1211
imp_splitter = ImportanceSplitter(num_parties=4, weights=[1, 1, 1, 3])
1312
Xs = imp_splitter.split(X)
1413

15-
# Evaluate split by importance
14+
#%% Evaluate split by importance
1615
model = LogisticRegression()
1716
model.fit(X, y)
1817
imp_evaluator = ImportanceEvaluator()
1918
imp_scores = imp_evaluator.evaluate(Xs, model.predict)
2019
alpha = imp_evaluator.evaluate_alpha(scores=imp_scores)
2120
print(f"Importance scores: {imp_scores}, alpha: {alpha}")
2221

23-
# Split by correlation
22+
#%% Split by correlation
2423
corr_splitter = CorrelationSplitter(num_parties=4)
2524
Xs = corr_splitter.fit_split(X)
2625

27-
# Evaluate split by correlation
26+
#%% Evaluate split by correlation
2827
corr_evaluator = CorrelationEvaluator()
2928
corr_scores = corr_evaluator.fit_evaluate(Xs)
3029
beta = corr_evaluator.evaluate_beta()
3130
print(f"Correlation scores: {corr_scores}, beta: {beta}")
3231

32+
#%%

pyproject.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ urls = { "Homepage" = "https://github.com/Xtra-Computing/VertiBench" }
3131
# Define your console scripts under [project.scripts] if any in the future
3232

3333
dependencies = [
34-
"matplotlib>=3.7",
35-
"numpy==1.24",
36-
"pymoo==0.6.1.1",
37-
"scikit-learn>=1.2",
38-
"scipy>=1.10",
39-
"shap==0.43",
40-
"torch>=2.0",
34+
"matplotlib>=3.7,<4",
35+
"numpy>=2.1,<3",
36+
"pymoo>=0.6.1,<0.6.1.6",
37+
"scikit-learn>=1.6,<2",
38+
"scipy>=1.10,<2",
39+
"shap>=0.50.0,<1",
40+
"torch>=2.0,<3",
4141
]
4242

4343
[project.optional-dependencies]
-254 Bytes
Binary file not shown.
-5.66 KB
Binary file not shown.
-3.23 KB
Binary file not shown.

src/vertibench/Evaluator.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
import warnings
2-
from typing import Iterable
3-
import time
41
import os
5-
62
import numpy as np
73
import pandas as pd
84
import torch
95
import torch.linalg
10-
from scipy.stats import spearmanr, hmean, gmean
6+
from scipy.stats import spearmanr
117
from sklearn.utils.extmath import randomized_svd
128
import shap
139
import matplotlib.pyplot as plt
@@ -17,7 +13,6 @@
1713
from pymoo.core.duplicate import ElementwiseDuplicateElimination
1814
from pymoo.core.problem import ElementwiseProblem
1915
from pymoo.optimize import minimize
20-
from pymoo.termination.default import DefaultSingleObjectiveTermination
2116
from pymoo.core.problem import StarmapParallelization
2217
from multiprocessing.pool import ThreadPool
2318

src/vertibench/Splitter.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
from numbers import Real
22
import warnings
33
import abc
4-
54
import numpy as np
65
import torch
7-
import torch.linalg
86
from pymoo.algorithms.soo.nonconvex.brkga import BRKGA
97
from pymoo.core.duplicate import ElementwiseDuplicateElimination
108
from pymoo.core.problem import ElementwiseProblem
119
from pymoo.optimize import minimize
1210
from pymoo.termination.default import DefaultSingleObjectiveTermination
1311
from pymoo.core.problem import StarmapParallelization
1412
from multiprocessing.pool import ThreadPool
15-
1613
from .Evaluator import CorrelationEvaluator
1714

1815

test/test_evaluate_alpha.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
import unittest
2-
3-
from collections import defaultdict
42
import numpy as np
53
import xgboost as xgb
64
from sklearn.datasets import make_classification
75
from sklearn.linear_model import LinearRegression
8-
9-
from src.vertibench.Evaluator import ImportanceEvaluator, CorrelationEvaluator
6+
from vertibench.Evaluator import ImportanceEvaluator
107

118

129
class TestAlphaEvaluator(unittest.TestCase):
@@ -58,3 +55,7 @@ def test_evaluate_alpha(self):
5855
lr = LinearRegression()
5956
lr.fit(np.array(split_ratios).reshape(-1, 1), np.array(alpha1s).reshape(-1, 1))
6057
self.assertGreater(lr.coef_[0][0], 0)
58+
59+
60+
if __name__ == '__main__':
61+
unittest.main()

test/test_evaluator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import unittest
2-
32
from collections import defaultdict
43
import numpy as np
54
import xgboost as xgb
65
from sklearn.datasets import make_classification
7-
8-
from src.vertibench.Evaluator import ImportanceEvaluator, CorrelationEvaluator
6+
from vertibench.Evaluator import ImportanceEvaluator, CorrelationEvaluator
97

108

119
def generate_data():

test/test_splitter.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import unittest
2-
import random
32
from itertools import product
43
from collections import defaultdict
54
from sklearn.datasets import make_classification
6-
75
import numpy as np
86
from scipy.stats import spearmanr
97
import xgboost as xgb
10-
from src.vertibench.Splitter import ImportanceSplitter, CorrelationSplitter, SimpleSplitter
11-
from src.vertibench.Evaluator import ImportanceEvaluator, CorrelationEvaluator
8+
from vertibench.Splitter import ImportanceSplitter, CorrelationSplitter
129

1310

1411
def generate_data():
@@ -263,4 +260,5 @@ def test_fit_split(self):
263260
self.assertAlmostEqual(beta, beta_eval, delta=0.1)
264261

265262

266-
263+
if __name__ == '__main__':
264+
unittest.main()

0 commit comments

Comments
 (0)