Commit 5d443df

Merge fix/kwargs-bug: fix CorrelationSplitter kwargs handling, harden tests

2 parents 71eb422 + 3965dc1

File tree: 7 files changed, +260 −95 lines

.github/workflows/python-publish.yml

Lines changed: 3 additions & 2 deletions

```diff
@@ -9,8 +9,9 @@
 name: Upload Python Package

 on:
-  release:
-    types: [published]
+  push:
+    tags:
+      - 'v*'

 permissions:
   contents: read
```

.gitignore

Lines changed: 3 additions & 1 deletion

```diff
@@ -166,4 +166,6 @@ cython_debug/
 .idea/

 # VSCode
-.vscode/
+.vscode/uv.lock
+test/tmp/
+.claude/
```

CLAUDE.md

Lines changed: 72 additions & 0 deletions (new file)

# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

VertiBench is a Python library for benchmarking vertical federated learning (VFL). It generates synthetic VFL datasets with tunable feature importance imbalance and inter-party correlation, then evaluates the quality of vertical data partitions along those two dimensions.
## Build & Development Commands

```bash
# Install from source (editable)
pip install -e .

# Install with test dependencies (adds xgboost)
pip install -e ".[test]"

# Build distribution
python -m build

# Run all tests
python -m unittest discover test/

# Run individual test files
python -m unittest test.test_splitter
python -m unittest test.test_evaluator
python -m unittest test.test_evaluate_alpha

# Run a single test case
python -m unittest test.test_splitter.TestImportanceSplitter.test_split_tabular
```

No linter or formatter is configured for this project.
## Architecture

The library lives in `src/vertibench/` and has two core modules:

### Splitter.py — Vertical Data Partitioning

Abstract base class `Splitter` defines the interface: `split_indices()` returns per-party feature index lists, and `split()` applies them to datasets.

Three implementations:

- **ImportanceSplitter** — Uses a Dirichlet distribution to assign features to parties with controllable importance imbalance. The `weights` parameter controls the expected importance per party (higher weight = more features).
- **CorrelationSplitter** — Uses the BRKGA genetic algorithm (via pymoo) to find partitions that match a target inter/intra-party correlation ratio. The parameter `beta` in [0, 1] controls the balance. Requires `fit()` on data before splitting.
- **SimpleSplitter** — Uniform contiguous split of features across parties.
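The Dirichlet-based allocation idea behind `ImportanceSplitter` can be sketched in plain numpy. This is an illustration of the mechanism described above, not VertiBench's implementation; the function name `dirichlet_feature_split` is made up for the sketch.

```python
import numpy as np

def dirichlet_feature_split(n_features, weights, seed=None):
    """Assign feature indices to parties with Dirichlet-distributed shares.

    Sketch of the idea behind ImportanceSplitter (not the library's code):
    each party's expected share of features is proportional to its weight,
    and smaller weights produce more imbalanced allocations.
    """
    rng = np.random.default_rng(seed)
    probs = rng.dirichlet(weights)  # per-party probability of owning a feature
    owners = rng.choice(len(weights), size=n_features, p=probs)
    return [np.where(owners == p)[0] for p in range(len(weights))]

parties = dirichlet_feature_split(10, weights=[1.0, 1.0, 1.0], seed=0)
# every feature index is assigned to exactly one party
assert sorted(np.concatenate(parties).tolist()) == list(range(10))
```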
### Evaluator.py — Split Quality Assessment

- **ImportanceEvaluator** — Computes per-party feature importance using the SHAP Permutation explainer. `evaluate_alpha()` recovers the Dirichlet concentration parameter from the importance scores.
- **CorrelationEvaluator** — Computes correlation matrices and scores inner- vs. inter-party correlation. `evaluate_beta()` recovers the correlation concentration metric. Supports GPU acceleration via PyTorch (`gpu_id` parameter). Uses multiple SVD strategies depending on feature count (exact for fewer than 100 features, randomized for larger).
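The inner- vs. inter-party distinction can be made concrete with a simplified statistic: mean absolute Pearson correlation within parties versus across them. This is a stand-in for `CorrelationEvaluator` (whose real metric is SVD-based, per the description above); `party_corr_scores` is a hypothetical helper.

```python
import numpy as np

def party_corr_scores(Xs):
    """Mean absolute Pearson correlation inside vs. across parties.

    Simplified stand-in for CorrelationEvaluator's statistic; the real
    implementation uses SVD-based scores, not raw means.
    """
    X = np.concatenate(Xs, axis=1)
    corr = np.abs(np.corrcoef(X, rowvar=False))
    d = X.shape[1]
    # which party owns each column (parties are contiguous blocks here)
    bounds = np.cumsum([x.shape[1] for x in Xs])
    party = np.searchsorted(bounds, np.arange(d), side='right')
    same = party[:, None] == party[None, :]
    off_diag = ~np.eye(d, dtype=bool)
    inner = corr[same & off_diag].mean()   # within-party pairs
    inter = corr[~same].mean()             # cross-party pairs
    return inner, inter

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 6))
inner, inter = party_corr_scores([X[:, :3], X[:, 3:]])
assert 0.0 <= inner <= 1.0 and 0.0 <= inter <= 1.0
```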
### Key Data Flow

1. Generate data (e.g., `sklearn.datasets.make_classification`)
2. `Splitter.split(X)` → list of per-party feature matrices `Xs`
3. `Evaluator.evaluate(Xs, ...)` → quality scores
4. `evaluate_alpha()` / `evaluate_beta()` → concentration metrics
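The steps above can be sketched end to end with plain numpy, using a contiguous split in place of the library's splitters and a crude variance-share score in place of the evaluators. All names here are illustrative, not VertiBench's API.

```python
import numpy as np

# 1. Generate data (stand-in for sklearn.datasets.make_classification)
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 8))

# 2. Split features across parties (SimpleSplitter-style contiguous split)
num_parties = 4
Xs = np.array_split(X, num_parties, axis=1)

# 3. A crude per-party "importance" score: each party's share of total variance
var_share = np.array([x.var(axis=0).sum() for x in Xs])
var_share = var_share / var_share.sum()

# 4. evaluate_alpha()/evaluate_beta() would recover concentration parameters
#    from scores like these; here we only sanity-check the shares.
assert np.isclose(var_share.sum(), 1.0)
assert all(x.shape[0] == X.shape[0] for x in Xs)
```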
### Design Patterns

- `Splitter` uses ABC + template method: concrete classes implement `split_indices()`, the base class handles the `split()` logic.
- `CorrelationSplitter` composes a `CorrelationEvaluator` internally for optimization.
- Correlation computation has multiple backends: Spearman (pandas), Pearson (numpy/torch), with CPU/GPU variants.
## Testing

Tests use `unittest` with `subTest()` for parameterized variants. Test data is generated synthetically via `generate_data()` and `split_data()` helpers in each test file. The evaluator tests train actual XGBoost models, so the `[test]` extras are required.
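The `subTest()` parameterization pattern mentioned above looks like this. The test class and the `contiguous_split` helper are stand-ins for illustration; the suite's actual helpers (`generate_data()`, `split_data()`) differ.

```python
import unittest

class TestContiguousSplit(unittest.TestCase):
    """Illustrates the subTest() pattern used by the suite (hypothetical test)."""

    @staticmethod
    def contiguous_split(n_features, num_parties):
        # stand-in for the SimpleSplitter-style contiguous partition
        cuts = [n_features * p // num_parties for p in range(num_parties + 1)]
        return [list(range(cuts[p], cuts[p + 1])) for p in range(num_parties)]

    def test_all_features_covered(self):
        for num_parties in (1, 2, 3, 5):
            # subTest keeps iterating after a failure and labels each variant
            with self.subTest(num_parties=num_parties):
                parts = self.contiguous_split(10, num_parties)
                flat = [i for part in parts for i in part]
                self.assertEqual(flat, list(range(10)))

if __name__ == "__main__":
    unittest.main(argv=["prog"], exit=False)
```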
## Dependencies

Key: numpy, scipy, scikit-learn, torch, shap, pymoo, matplotlib. Python >= 3.9.

HyperParameters.md

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/vertibench/Splitter.py

Lines changed: 25 additions & 13 deletions

```diff
@@ -26,14 +26,18 @@ def split_indices(self, *args, **kwargs):
         """
         pass

-    def split(self, *Xs, indices=None, allow_empty_party=False, fill=None):
+    def split(self, *Xs, indices=None, allow_empty_party=False, fill=None, **kwargs):
         assert len(Xs) > 0, "At least one dataset should be given"
+        n_features = Xs[0].shape[1]
+        if n_features < self.num_parties:
+            raise ValueError(
+                f"Number of features ({n_features}) must be >= number of parties ({self.num_parties})")
         ans = []

         # calculate the indices for each party for all datasets
         if indices is None:
             allX = np.concatenate(Xs, axis=0)
-            party_to_feature = self.split_indices(allX, allow_empty_party=allow_empty_party)
+            party_to_feature = self.split_indices(allX, allow_empty_party=allow_empty_party, **kwargs)
         else:
             party_to_feature = indices

@@ -57,12 +61,12 @@ def split(self, *Xs, indices=None, allow_empty_party=False, fill=None):


 class ImportanceSplitter(Splitter):
-    def __init__(self, num_parties, weights=1, seed=None):
+    def __init__(self, num_parties, weights=1., seed=None):
         """
         Split a 2D dataset by feature importance under dirichlet distribution (assuming the features are independent).
         :param num_parties: [int] number of parties
-        :param weights: [int | list with size num_parties]
-            If weights is an int, the weight of each party is the same.
+        :param weights: [float | list with size num_parties]
+            If weights is a float, the weight of each party is the same. Equivalent to an array of [weights]*num_parties.
             If weights is an array, the weight of each party is the corresponding element in the array.
             The weights indicate the expected sum of feature importance of each party.
             Meanwhile, larger weights mean less bias on the feature importance.
@@ -72,8 +76,8 @@ def __init__(self, num_parties, weights=1, seed=None):
         self.weights = weights
         self.seed = seed
         np.random.seed(seed)
-        if isinstance(self.weights, Real):
-            self.weights = [self.weights for _ in range(self.num_parties)]
+        if isinstance(self.weights, Real):  # both int & float values pass this 'if'
+            self.weights = [self.weights for _ in range(self.num_parties)]  # a uniform weights array is constructed

         self.check_params()

@@ -103,7 +107,7 @@ def dirichlet(alpha):
         xs.append(1 - sum(xs))
         return np.array(xs)

-    def split_indices(self, X, allow_empty_party=False):
+    def split_indices(self, X, allow_empty_party=False, **kwargs):
         """
         Split the indices of X by feature importance.
         :param allow_empty_party: [bool] whether to allow parties with zero features
@@ -168,7 +172,7 @@ def __init__(self, num_parties: int, evaluator: CorrelationEvaluator = None, see
         super().__init__(num_parties)
         self.evaluator = evaluator
         if evaluator is None:
-            self.evaluator = CorrelationEvaluator(gpu_id=gpu_id)
+            self.evaluator = CorrelationEvaluator(gpu_id=gpu_id, n_jobs=n_jobs)
         self.seed = seed
         self.gpu_id = gpu_id
         if self.gpu_id is not None:
@@ -320,7 +324,7 @@ def split_indices(self, X, n_elites=20, n_offsprings=70, n_mutants=10, n_gen=100
         self.best_icor = res_beta.opt.get('icor')[0]
         self.best_error = res_beta.F[0]
         # print(f"Best permutation order: {permute_order}")
-        # print(f"Beta {self.beta}, Best match icor: {best_match_icor}")
+        # print(f"Beta {beta}, Best match icor: {self.best_icor}")

         # summarize the feature ids on each party
         party_cut_points = np.cumsum(self.evaluator.n_features_on_party)
@@ -333,9 +337,17 @@ def split_indices(self, X, n_elites=20, n_offsprings=70, n_mutants=10, n_gen=100
         assert (np.sort(np.concatenate(self.best_feature_per_party)) == np.arange(X.shape[1])).all()
         return self.best_feature_per_party

-    def fit_split(self, X, **kwargs):
-        self.fit(X, **kwargs)
-        return self.split(X, **kwargs)
+    def fit_split(self, X, beta=0.5, **fit_kwargs):
+        """
+        Fit the splitter and split the data.
+        :param X: [np.ndarray] 2D dataset
+        :param beta: [float] the tightness of inner-party correlation (passed to split_indices, not fit)
+        :param fit_kwargs: additional keyword arguments passed to fit() (BRKGA parameters for fit_min_max)
+        """
+        if not (0 <= beta <= 1):
+            raise ValueError(f"beta should be in [0, 1], got {beta}")
+        self.fit(X, **fit_kwargs)
+        return self.split(X, beta=beta)

     def visualize(self, *args, **kwargs):
         return self.evaluator.visualize(*args, **kwargs)
```
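The core of the kwargs fix above is a template-method pattern in which the base class's `split()` forwards extra keyword arguments through to the subclass's `split_indices()`. A minimal standalone sketch of that pattern follows; the class names mirror the diff but the bodies are simplified illustrations, not the library's code.

```python
import numpy as np
from abc import ABC, abstractmethod

class Splitter(ABC):
    def __init__(self, num_parties):
        self.num_parties = num_parties

    @abstractmethod
    def split_indices(self, X, **kwargs):
        ...

    def split(self, X, **kwargs):
        # Forward **kwargs so subclass-specific options (e.g. beta for a
        # correlation-based splitter) reach split_indices(); dropping them
        # here is the kind of bug this commit fixes.
        if X.shape[1] < self.num_parties:
            raise ValueError("fewer features than parties")
        idx = self.split_indices(X, **kwargs)
        return [X[:, i] for i in idx]

class SimpleSplitter(Splitter):
    def split_indices(self, X, shuffle=False, seed=None, **kwargs):
        cols = np.arange(X.shape[1])
        if shuffle:  # subclass-specific option, delivered via **kwargs
            cols = np.random.default_rng(seed).permutation(cols)
        return np.array_split(cols, self.num_parties)

X = np.arange(20.).reshape(4, 5)
Xs = SimpleSplitter(2).split(X, shuffle=True, seed=0)
assert sum(x.shape[1] for x in Xs) == X.shape[1]
```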
