Skip to content

Commit bfb8007

Browse files
committed
Update documentation
1 parent f545b5b commit bfb8007

File tree

10 files changed

+298
-179
lines changed

10 files changed

+298
-179
lines changed

docs/api.md

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/gen_ref_pages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@
6666
with mkdocs_gen_files.open(f"{doc_dir}/index.md", "w") as fd:
6767
fd.write("# API Reference\n\n")
6868
fd.write("## Overview\n\n")
69+
70+
# Only show modules, not individual classes
6971
fd.write(f"::: {src_dir}\n")
7072
fd.write(" options:\n")
7173
fd.write(" show_category_heading: false\n")
7274
fd.write(" members_order: source\n")
7375
fd.write(" filters: ['!^_', '!^Parameters$']\n")
7476
fd.write(" show_root_heading: false\n")
7577
fd.write(" heading_level: 3\n")
78+
fd.write(" members: false\n") # This line prevents showing class members

docs/index.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
# Welcome to Manify
22

3-
A library for geometric ML with manifold-based methods.
4-
5-
- 📚 API Reference: [API](api.md)
3+
A Python Library for Learning Non-Euclidean Representations

manify/embedders/_losses.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ def distortion_loss(
1616
D_true: Float[torch.Tensor, "n_points n_points"],
1717
pairwise: bool = False,
1818
) -> Float[torch.Tensor, ""]:
19-
"""
20-
Compute the distortion loss between estimated SQUARED distances and true SQUARED distances.
19+
"""Compute the distortion loss between estimated SQUARED distances and true SQUARED distances.
20+
2121
Args:
22-
D_est (n_points, n_points): A tensor of estimated pairwise distances.
23-
D_true (n_points, n_points).: A tensor of true pairwise distances.
24-
pairwise (bool): A boolean indicating whether to return whether D_est and D_true are pairwise
22+
D_est: A tensor of estimated pairwise distances.
23+
D_true: A tensor of true pairwise distances.
24+
pairwise: A boolean indicating whether D_est and D_true are pairwise distance matrices.
2525
2626
Returns:
2727
float: A float indicating the distortion loss, calculated as the sum of the squared relative
28-
errors between the estimated and true squared distances.
28+
errors between the estimated and true squared distances.
2929
30-
See also: square_loss in HazyResearch hyperbolics repo:
30+
See also: `square_loss` in HazyResearch hyperbolics repo:
3131
https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
3232
"""
3333

manify/manifolds.py

Lines changed: 206 additions & 116 deletions
Large diffs are not rendered by default.

manify/predictors/_kernel.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ def compute_kernel_and_norm_manifold(
2020
2121
Args:
2222
manifold: The manifold in which the computation occurs.
23-
X_source((n_points_source, n_dim)): A tensor of the source points
24-
X_target("n_points_target", "n_dim"): A tensor of target points
23+
X_source: A tensor of the source points
24+
X_target: A tensor of target points
2525
2626
Returns:
27-
Tuple("n_points_source", "n_points_target"): A tuple of two tensors. The first tensor
28-
is the kernel matrix of shape computed based on the manifold type. And the second tensor
29-
A scalar normalization constant for the kernel, determined by the manifold's curvature or scale.
27+
A tuple of two tensors. The first tensor is the kernel matrix computed based on the manifold type.
28+
And the second tensor is a scalar normalization constant for the kernel, determined by the manifold's curvature
29+
or scale.
3030
"""
3131
if X_target is None:
3232
X_target = X_source
@@ -72,13 +72,13 @@ def product_kernel(
7272
7373
Args:
7474
pm: The product manifold in which the computation occurs.
75-
X_source((n_points_source, n_dim)): A tensor of the source points
76-
X_target("n_points_target", "n_dim"): A tensor of target points
75+
X_source: A tensor of the source points
76+
X_target: A tensor of target points
7777
7878
Returns:
79-
Tuple("n_points_source", "n_points_target"): A tuple of two tensors. The first tensor is the
80-
kernel matrix of shape computed based on the product manifold type. And the second tensor is a
81-
scalar normalization constant for the kernel, determined by the product manifold's curvature or scale.
79+
A tuple of two tensors. The first tensor is the kernel matrix computed based on the product manifold
80+
type. And the second tensor is a scalar normalization constant for the kernel, determined by the product
81+
manifold's curvature or scale.
8282
"""
8383
# If X_target is None, set it to X_source
8484
if X_target is None:

manify/predictors/_midpoint.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def hyperbolic_midpoint(u: float, v: float, assert_hyperbolic: bool = False) ->
1616
u: The first angular coordinate.
1717
v: The second angular coordinate.
1818
assert_hyperbolic: A boolean value. If True, verifies that the midpoint satisfies the hyperbolic
19-
distance property. Defaults to False.
19+
distance property. Defaults to False.
2020
2121
Returns:
2222
torch.Tensor: The computed hyperbolic midpoint between u and v.
@@ -81,22 +81,20 @@ def midpoint(
8181
manifold: Manifold,
8282
special_first: bool = False,
8383
) -> Float[torch.Tensor, ""]:
84-
"""
85-
Driver code to compute the midpoint between two angular coordinates give the manifold type.
84+
"""Compute the midpoint between two angular coordinates given the manifold type.
8685
8786
This function automatically selects the appropriate midpoint calculation depending
8887
on the manifold type. It supports hyperbolic, Euclidean, and spherical geometries.
8988
9089
Args:
9190
u: The first angular coordinate.
9291
v: The second angular coordinate.
93-
manifold (Manifold): An object representing the manifold type.
94-
special_first (bool, optional): If True, uses the manifold-specific midpoint
95-
calculations given the manifold type of hyperbolic or euclidean. Defaults to False.
92+
manifold: An object representing the manifold type.
93+
special_first: If True, uses the manifold-specific midpoint calculations given the manifold type of hyperbolic
94+
or Euclidean. Defaults to False.
9695
9796
Returns:
9897
torch.Tensor: The computed midpoint between u and v, based on the selected geometry.
99-
10098
"""
10199
if torch.isclose(u, v):
102100
return u

manify/predictors/kappa_gcn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,7 @@ def fit(
377377
Args:
378378
X (torch.Tensor): Feature matrix.
379379
y (torch.Tensor): Labels for training nodes.
380-
adj (torch.Tensor): Adjacency or distance matrix.
381-
train_idx (torch.Tensor): Indices of nodes for training.
380+
A (torch.Tensor): Adjacency or distance matrix.
382381
epochs: Number of training epochs (default=200).
383382
lr: Learning rate (default=1e-2).
384383
use_tqdm: Whether to use tqdm for progress bar.

manify/utils/benchmarks.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -103,39 +103,65 @@ def benchmark(
103103
lp_train_idx: Optional[Float[torch.Tensor, "n_samples,"]] = None,
104104
lp_test_idx: Optional[Float[torch.Tensor, "n_samples,"]] = None,
105105
) -> Dict[str, float]:
106-
"""
107-
Benchmarks various machine learning models on a dataset using a product manifold structure.
106+
"""Benchmarks various machine learning models on Riemannian manifold datasets.
107+
108+
Evaluates and compares different machine learning models on datasets with a
109+
product manifold structure, providing metrics for their performance.
108110
109111
Args:
110-
X (batch, dim): Input tensor of features
111-
y (batch,): Input tensor of labels.
112-
pm: The defined product manifold for benchmarks.
113-
split: Data splitting strategy ('train_test' or 'cross_val').
114-
device: Device for computation ('cpu', 'cuda', 'mps').
115-
score: Scoring metric for model evaluation ('accuracy', 'f1-micro', etc.).
112+
X: Tensor of input features with shape (batch, dim).
113+
y: Tensor of target labels with shape (batch,).
114+
pm: ProductManifold object defining the geometric structure for benchmarks.
115+
device: Device for computation. Options: 'cpu', 'cuda', 'mps'. Defaults to 'cpu'.
116+
score: List of scoring metrics for model evaluation (e.g., 'accuracy', 'f1-micro').
117+
Defaults to None.
116118
models: List of model names to evaluate. Options include:
117-
* "sklearn_dt": Decision tree from scikit-learn.
118-
* "sklearn_rf": Random forest from scikit-learn.
119-
* "product_dt": Product space decision tree.
120-
* "product_rf": Product space random forest.
121-
* "tangent_dt": Decision tree on tangent space.
122-
* "tangent_rf": Random forest on tangent space.
123-
* "knn": k-nearest neighbors.
124-
* "ps_perceptron": Product space perceptron.
125-
max_depth: Maximum depth of tree-based models in integer.
126-
n_estimators: Integer number of estimators for random forest models.
127-
min_samples_split: Minimum number of samples required to split an internal node.
128-
min_samples_leaf: Minimum number of samples in a leaf node.
129-
task: Task type ('classification' or 'regression').
130-
seed: Random seed for reproducibility.
131-
use_special_dims: Boolean for whether to use special manifold dimensions.
132-
n_features: Feature dimensionality type ('d' or 'd_choose_2').
133-
X_train, X_test, y_train, y_test: Training and testing datasets, X: feature, y: label.
134-
batch_size: Batch size for certain models.
119+
* "sklearn_dt": Decision tree from scikit-learn
120+
* "sklearn_rf": Random forest from scikit-learn
121+
* "product_dt": Product space decision tree
122+
* "product_rf": Product space random forest
123+
* "tangent_dt": Decision tree on tangent space
124+
* "tangent_rf": Random forest on tangent space
125+
* "knn": k-nearest neighbors
126+
* "ps_perceptron": Product space perceptron
127+
Defaults to None.
128+
max_depth: Maximum depth of tree-based models. Defaults to 5.
129+
n_estimators: Number of estimators for ensemble models. Defaults to 12.
130+
min_samples_split: Minimum samples required to split an internal node. Defaults to 2.
131+
min_samples_leaf: Minimum samples required in a leaf node. Defaults to 1.
132+
task: Type of machine learning task. Options: 'classification' or 'regression'.
133+
Defaults to 'classification'.
134+
seed: Random seed for reproducibility. Defaults to None.
135+
use_special_dims: Whether to use special manifold dimensions. Defaults to False.
136+
n_features: Feature dimensionality type. Options: 'd' or 'd_choose_2'.
137+
Defaults to 'd_choose_2'.
138+
X_train: Training feature tensor with shape (n_samples, n_manifolds).
139+
If provided, overrides split from X. Defaults to None.
140+
X_test: Testing feature tensor with shape (n_samples, n_manifolds).
141+
If provided, used with X_train. Defaults to None.
142+
y_train: Training labels tensor with shape (n_samples,).
143+
Must be provided if X_train is given. Defaults to None.
144+
y_test: Testing labels tensor with shape (n_samples,).
145+
Must be provided if X_test is given. Defaults to None.
146+
batch_size: Batch size for neural network models. Defaults to None.
147+
adj: Adjacency matrix for graph-based models with shape (n_nodes, n_nodes).
148+
Defaults to None.
149+
A_train: Training adjacency matrix with shape (n_samples, n_samples).
150+
Defaults to None.
151+
A_test: Testing adjacency matrix with shape (n_samples, n_samples).
152+
Defaults to None.
153+
hidden_dims: List of hidden layer dimensions for neural networks.
154+
Defaults to None.
155+
epochs: Number of training epochs for iterative models. Defaults to 4000.
156+
lr: Learning rate for gradient-based optimization. Defaults to 1e-4.
157+
kappa_gcn_layers: Number of layers in GCN models. Defaults to 1.
158+
lp_train_idx: Training indices for link prediction with shape (n_samples,).
159+
Defaults to None.
160+
lp_test_idx: Testing indices for link prediction with shape (n_samples,).
161+
Defaults to None.
135162
136163
Returns:
137-
Dict[str, float]: Dictionary of model names and their corresponding evaluation scores.
138-
164+
Dictionary mapping model names to their corresponding evaluation scores.
139165
"""
140166
if score is None:
141167
score = ["accuracy", "f1-micro", "f1-macro"]

mkdocs.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ plugins:
2424
# Explicitly define navigation structure
2525
nav:
2626
- Home: index.md
27+
- Installation: installation.md
2728
- API Reference:
2829
- Overview: reference/index.md
2930
- Manifolds: reference/manifolds.md
@@ -37,12 +38,15 @@ nav:
3738
- Coordinate Learning: reference/embedders/coordinate_learning.md
3839
- Siamese: reference/embedders/siamese.md
3940
- VAE: reference/embedders/vae.md
41+
- Losses: reference/embedders/_losses.md
4042
- Predictors:
4143
- Overview: reference/predictors/index.md
4244
- Decision Tree: reference/predictors/decision_tree.md
4345
- Kappa GCN: reference/predictors/kappa_gcn.md
4446
- Perceptron: reference/predictors/perceptron.md
4547
- SVM: reference/predictors/svm.md
48+
- Kernel: reference/predictors/_kernel.md
49+
- Midpoint: reference/predictors/_midpoint.md
4650
- Utils:
4751
- Overview: reference/utils/index.md
4852
- Benchmarks: reference/utils/benchmarks.md
@@ -56,4 +60,10 @@ markdown_extensions:
5660
- pymdownx.highlight
5761
- pymdownx.superfences
5862
- toc:
59-
permalink: true
63+
permalink: true
64+
- pymdownx.arithmatex:
65+
generic: true
66+
67+
extra_javascript:
68+
- https://polyfill.io/v3/polyfill.min.js?features=es6
69+
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js

0 commit comments

Comments
 (0)