Skip to content

Commit e0f26ad

Browse files
committed
Check membership with dicts; better default handling for None sentinels
1 parent 8355612 commit e0f26ad

File tree

11 files changed

+87
-121
lines changed

11 files changed

+87
-121
lines changed

manify/curvature_estimation/_pipelines.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ def distortion_pipeline(
3535
A function f(signature) → loss, where signature is a list
3636
of (curvature, dim) tuples.
3737
"""
38-
if embedder_init_kwargs is None:
39-
embedder_init_kwargs = {}
40-
if embedder_fit_kwargs is None:
41-
embedder_fit_kwargs = {}
38+
embedder_init_kwargs = embedder_init_kwargs or {}
39+
embedder_fit_kwargs = embedder_fit_kwargs or {}
4240

4341
dists = dists.to(pm.device)
4442
dists_rescaled = dists / dists.max()
@@ -82,14 +80,10 @@ def classifier_pipeline(
8280
Returns:
8381
The loss of the classifier on the test set after embedding the distances using the product manifold.
8482
"""
85-
if embedder_init_kwargs is None:
86-
embedder_init_kwargs = {}
87-
if embedder_fit_kwargs is None:
88-
embedder_fit_kwargs = {}
89-
if model_init_kwargs is None:
90-
model_init_kwargs = {}
91-
if model_fit_kwargs is None:
92-
model_fit_kwargs = {}
83+
embedder_init_kwargs = embedder_init_kwargs or {}
84+
embedder_fit_kwargs = embedder_fit_kwargs or {}
85+
model_init_kwargs = model_init_kwargs or {}
86+
model_fit_kwargs = model_fit_kwargs or {}
9387

9488
dists = dists.to(pm.device)
9589
dists_rescaled = dists / dists.max()

manify/embedders/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class BaseEmbedder(BaseEstimator, TransformerMixin, ABC):
3333
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
3434
self.pm = pm
3535
self.random_state = random_state
36-
self.device = pm.device if device is None else device
36+
self.device = device or pm.device
3737
self.loss_history_: dict[str, list[float]] = {}
3838
self.is_fitted_: bool = False
3939

manify/manifolds.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,11 @@ def sample(
224224
x: Tensor of sampled points on the manifold
225225
v: Tensor of tangent vectors
226226
"""
227-
if z_mean is None:
228-
z_mean = self.mu0
227+
z_mean = self.mu0 if z_mean is None else z_mean
229228
z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
230229
n = z_mean.shape[0]
231-
if sigma is None:
232-
sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device)
233-
else:
234-
sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
230+
sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device) if sigma is None else sigma
231+
sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
235232
assert sigma.shape == (
236233
n,
237234
self.dim,
@@ -284,14 +281,11 @@ def log_likelihood(
284281
`mu` and covariance `sigma`.
285282
"""
286283
# Default to mu=self.mu0 and sigma=I
287-
if mu is None:
288-
mu = self.mu0
284+
mu = self.mu0 if mu is None else mu
289285
mu = torch.Tensor(mu).reshape(-1, self.ambient_dim).to(self.device)
290286
n = mu.shape[0]
291-
if sigma is None:
292-
sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device)
293-
else:
294-
sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
287+
sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device) if sigma is None else sigma
288+
sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
295289

296290
# Euclidean case is regular old Gaussian log-likelihood
297291
if self.type == "E":
@@ -336,8 +330,7 @@ def logmap(
336330
Returns:
337331
logmap_result: Tensor representing the result of the logarithmic map from `base` to `x` on the manifold.
338332
"""
339-
if base is None:
340-
base = self.mu0
333+
base = self.mu0 if base is None else base
341334
return self.manifold.logmap(x=base, y=x)
342335

343336
def expmap(
@@ -355,8 +348,7 @@ def expmap(
355348
Returns:
356349
expmap_result: Tensor representing the result of the exponential map applied to `u` at the base point.
357350
"""
358-
if base is None:
359-
base = self.mu0
351+
base = self.mu0 if base is None else base
360352
return self.manifold.expmap(x=base, u=u)
361353

362354
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> tuple[Manifold, ...]:
@@ -633,18 +625,17 @@ def sample(
633625
x: Tensor of sampled points on the manifold
634626
v: Tensor of tangent vectors
635627
"""
636-
if z_mean is None:
637-
z_mean = self.mu0
628+
z_mean = self.mu0 if z_mean is None else z_mean
638629
z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
639630
n = z_mean.shape[0]
640631

641-
if sigma_factorized is None:
642-
sigma_factorized = [torch.stack([torch.eye(M.dim)] * n) for M in self.P]
643-
else:
644-
sigma_factorized = [
645-
torch.Tensor(sigma).reshape(-1, M.dim, M.dim).to(self.device)
646-
for M, sigma in zip(self.P, sigma_factorized, strict=False)
647-
]
632+
sigma_factorized = (
633+
[torch.stack([torch.eye(M.dim)] * n) for M in self.P] if sigma_factorized is None else sigma_factorized
634+
)
635+
sigma_factorized = [
636+
torch.Tensor(sigma).reshape(-1, M.dim, M.dim).to(self.device)
637+
for M, sigma in zip(self.P, sigma_factorized, strict=False)
638+
]
648639

649640
assert all(sigma.shape == (n, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized, strict=False)), (
650641
"Sigma matrices must match the dimensions of the manifolds."
@@ -684,12 +675,12 @@ def log_likelihood(
684675
`mu` and covariance `sigma`.
685676
"""
686677
n = z.shape[0]
687-
if mu is None:
688-
mu = torch.vstack([self.mu0] * n).to(self.device)
678+
mu = torch.vstack([self.mu0] * n).to(self.device) if mu is None else mu
689679

690-
if sigma_factorized is None:
691-
sigma_factorized = [torch.stack([torch.eye(M.dim)] * n) for M in self.P]
692-
# Note that this factorization assumes block-diagonal covariance matrices
680+
sigma_factorized = (
681+
[torch.stack([torch.eye(M.dim)] * n) for M in self.P] if sigma_factorized is None else sigma_factorized
682+
)
683+
# Note that this factorization assumes block-diagonal covariance matrices
693684

694685
mu_factorized = self.factorize(mu)
695686
z_factorized = self.factorize(z)
@@ -807,10 +798,8 @@ def gaussian_mixture(
807798
torch.manual_seed(seed)
808799

809800
# Deal with clusters
810-
if num_clusters is None:
811-
num_clusters = num_classes
812-
else:
813-
assert num_clusters >= num_classes, "Number of clusters must be at least as large as number of classes."
801+
num_clusters = num_clusters or num_classes
802+
assert num_clusters >= num_classes, "Number of clusters must be at least as large as number of classes."
814803

815804
# Adjust covariance matrices for number of dimensions
816805
if adjust_for_dims:

manify/predictors/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
self.pm = pm
4848
self.task = task
4949
self.random_state = random_state
50-
self.device = pm.device if device is None else device
50+
self.device = device or pm.device
5151
self.loss_history_: dict[str, list[float]] = {}
5252
self.is_fitted_: bool = False
5353

manify/predictors/_kernel.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def compute_kernel_and_norm_manifold(
2828
kernel_matrix: The kernel matrix between source and target points.
2929
norm_constant: Scalar normalization constant for the kernel.
3030
"""
31-
if X_target is None:
32-
X_target = X_source
31+
X_target = X_source if X_target is None else X_target
3332

3433
ip = manifold.inner(X_source, X_target)
3534
ip *= manifold.scale
@@ -77,9 +76,7 @@ def product_kernel(
7776
kernel_matrices: List of kernel matrices for each component manifold.
7877
norm_constants: List of normalization constants for each kernel.
7978
"""
80-
# If X_target is None, set it to X_source
81-
if X_target is None:
82-
X_target = X_source
79+
X_target = X_source if X_target is None else X_target
8380

8481
# Compute the kernel matrix and norm for each manifold
8582
Ks = []

manify/predictors/decision_tree.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,7 @@ def __init__(
292292

293293
# Store hyperparameters
294294
self.pm = pm
295-
if max_depth is None:
296-
self.max_depth = -1 # This runs forever since the loop checks depth == 0
297-
else:
298-
self.max_depth = max_depth
295+
self.max_depth = max_depth or -1
299296
self.min_samples_leaf = min_samples_leaf
300297
self.min_samples_split = min_samples_split
301298
self.min_impurity_decrease = min_impurity_decrease
@@ -370,7 +367,7 @@ def _preprocess(
370367
dims = self.pm.man2dim[i]
371368

372369
# Non-Euclidean manifolds use angular projections
373-
if M.type in ["H", "S"]:
370+
if M.type in {"H", "S"}:
374371
if self.n_features == "d":
375372
dim = dims[0]
376373
num = X[:, dim : dim + 1]
@@ -516,7 +513,7 @@ def _aggregate_special_dims(
516513
) -> tuple[Float[torch.Tensor, "batch ambient_dim"], ProductManifold]:
517514
special_dims = []
518515
for i, M in enumerate(self.pm.P):
519-
if M.type in ["H", "S"]:
516+
if M.type in {"H", "S"}:
520517
dim = self.pm.man2dim[i][0]
521518
special_dims.append(X[:, dim : dim + 1])
522519
if len(special_dims) > 0:
@@ -655,10 +652,7 @@ def __init__(
655652
tree_kwargs: Dict[str, Any] = {}
656653
self.pm = tree_kwargs["pm"] = pm
657654
self.task = tree_kwargs["task"] = task
658-
if max_depth is None:
659-
self.max_depth = tree_kwargs["max_depth"] = -1
660-
else:
661-
self.max_depth = tree_kwargs["max_depth"] = max_depth
655+
self.max_depth = tree_kwargs["max_depth"] = max_depth or -1
662656
self.min_samples_leaf = tree_kwargs["min_samples_leaf"] = min_samples_leaf
663657
self.min_samples_split = tree_kwargs["min_samples_split"] = min_samples_split
664658
self.min_impurity_decrease = tree_kwargs["min_impurity_decrease"] = min_impurity_decrease

manify/predictors/nn/layers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,11 @@ def __init__(
3535
):
3636
super().__init__()
3737

38-
# Parameters are Euclidean, straightforardly
39-
# self.W = torch.rand(in_features, out_features)
38+
# Parameters are Euclidean, straightforwardly
4039
self.W = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.01)
41-
# self.b = torch.nn.Parameter(torch.rand(out_features))
4240

43-
# Noninearity must be applied via the manifold
44-
if nonlinearity is None:
45-
self.sigma = lambda x: x
46-
else:
47-
self.sigma = lambda x: manifold.expmap(nonlinearity(manifold.logmap(x)))
41+
# Nonlinearity must be applied via the manifold
42+
self.sigma = manifold.apply(nonlinearity) if nonlinearity else lambda x: x
4843

4944
# Also store manifold
5045
self.manifold = manifold

manify/predictors/perceptron.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,8 @@ def __init__(
5252
self.pm = pm # ProductManifold instance
5353
self.max_epochs = max_epochs
5454
self.patience = patience # Number of consecutive epochs without improvement to consider convergence
55-
if weights is None:
56-
self.weights = torch.ones(len(pm.P), dtype=torch.float32)
57-
else:
58-
assert len(weights) == len(pm.P), "Number of weights must match the number of manifolds."
59-
self.weights = weights
55+
self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
56+
assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."
6057

6158
def fit(
6259
self, X: Float[torch.Tensor, "n_samples n_manifolds"], y: Int[torch.Tensor, "n_samples"]

manify/predictors/svm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
self.eps = epsilon
8383
self.task = task
8484
self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
85-
assert len(self.weights) == len(pm.P), "Number of weights must match manifolds."
85+
assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."
8686

8787
def fit(
8888
self,
@@ -99,7 +99,6 @@ def fit(
9999
self: Fitted ProductSpaceSVM instance.
100100
"""
101101
# unique classes
102-
# self.classes_ = torch.unique(y).tolist()
103102
self._store_classes(y)
104103
n = X.shape[0]
105104

manify/utils/benchmarks.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,14 @@ def _score(
6161
use_torch: bool = False,
6262
score: list[SCORETYPE] | None = None,
6363
) -> dict[SCORETYPE, float]:
64-
if score is None:
65-
score = ["accuracy", "f1-micro"]
66-
if y_pred_override is not None:
67-
y_pred = y_pred_override
68-
else:
69-
assert model is not None, "Model must be provided if y_pred_override is not given"
70-
y_pred = model.predict(_X)
64+
"""Helper function to score a model."""
65+
score = score or ["accuracy", "f1-micro"]
66+
assert model is not None or y_pred_override is not None, "Model must be provided if y_pred_override is not given"
67+
y_pred = y_pred_override if y_pred_override is not None else model.predict(_X) # type: ignore
68+
7169
if use_torch:
7270
y_pred = y_pred.detach().cpu().numpy()
71+
7372
scoring_funcs = {
7473
"accuracy": accuracy_score,
7574
"f1-micro": lambda y, p: f1_score(y, p, average="micro"),
@@ -164,42 +163,40 @@ def benchmark(
164163
Returns:
165164
Dictionary mapping model names to their corresponding evaluation scores.
166165
"""
167-
if score is None:
168-
score = ["accuracy", "f1-micro", "f1-macro"]
169-
if models is None:
170-
models = [
171-
"sklearn_dt",
172-
"sklearn_rf",
173-
"product_dt",
174-
"product_rf",
175-
"tangent_dt",
176-
"tangent_rf",
177-
"knn",
178-
"ps_perceptron",
179-
# "svm",
180-
"ps_svm",
181-
# "tangent_mlp",
182-
"ambient_mlp",
183-
"tangent_gcn",
184-
"ambient_gcn",
185-
"kappa_gcn",
186-
"ambient_mlr",
187-
"tangent_mlr",
188-
"kappa_mlr",
189-
"single_manifold_rf",
190-
]
166+
score = score or ["accuracy", "f1-micro", "f1-macro"]
167+
models = models or [
168+
"sklearn_dt",
169+
"sklearn_rf",
170+
"product_dt",
171+
"product_rf",
172+
"tangent_dt",
173+
"tangent_rf",
174+
"knn",
175+
"ps_perceptron",
176+
"svm",
177+
"ps_svm",
178+
"tangent_mlp",
179+
"ambient_mlp",
180+
"tangent_gcn",
181+
"ambient_gcn",
182+
"kappa_gcn",
183+
"ambient_mlr",
184+
"tangent_mlr",
185+
"kappa_mlr",
186+
"single_manifold_rf",
187+
]
191188

192189
# Input validation on (task, score) pairing
193-
if task in ["classification", "link_prediction"]:
194-
assert all(s in ["accuracy", "f1-micro", "f1-macro", "time"] for s in score)
190+
if task in {"classification", "link_prediction"}:
191+
assert all(s in {"accuracy", "f1-micro", "f1-macro", "time"} for s in score)
195192
elif task == "regression":
196-
assert all(s in ["mse", "rmse", "percent_rmse", "time"] for s in score)
193+
assert all(s in {"mse", "rmse", "percent_rmse", "time"} for s in score)
197194

198195
# Input validation on (task, score) pairing
199-
if task in ["classification", "link_prediction"]:
200-
assert all(s in ["accuracy", "f1-micro", "f1-macro", "time"] for s in score)
196+
if task in {"classification", "link_prediction"}:
197+
assert all(s in {"accuracy", "f1-micro", "f1-macro", "time"} for s in score)
201198
elif task == "regression":
202-
assert all(s in ["mse", "rmse", "percent_rmse", "time"] for s in score)
199+
assert all(s in {"mse", "rmse", "percent_rmse", "time"} for s in score)
203200
else:
204201
raise ValueError(f"Unknown task: {task}")
205202

@@ -243,7 +240,7 @@ def benchmark(
243240
X_train, X_test, y_train, y_test, train_idx, test_idx = train_test_split(X, y, np.arange(len(X)), test_size=0.2)
244241

245242
# Make sure classification labels are formatted correctly
246-
if task in ["classification", "link_prediction"]:
243+
if task in {"classification", "link_prediction"}:
247244
y = torch.unique(y, return_inverse=True)[1]
248245
y_train = y[train_idx]
249246
y_test = y[test_idx]
@@ -302,7 +299,7 @@ def benchmark(
302299
nn_train_kwargs = {"epochs": epochs, "lr": lr}
303300

304301
# Define your models
305-
if task in ["classification", "link_prediction"]:
302+
if task in {"classification", "link_prediction"}:
306303
dt_class = DecisionTreeClassifier
307304
rf_class = RandomForestClassifier
308305
knn_class = KNeighborsClassifier

0 commit comments

Comments (0)