Skip to content

Commit 384e17d

Browse files
committed
More typing fixes, mostly
1 parent 8f585d2 commit 384e17d

20 files changed

+128
-118
lines changed

.github/workflows/test.yml

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ jobs:
2828
- name: Install dependencies
2929
run: |
3030
python -m pip install --upgrade pip
31-
# Install base manify
3231
pip install -e .
33-
# install manify-dev
3432
pip install -e ".[dev]"
35-
# pip install -e ".[utils]"
36-
# install all extras
33+
pip install -e ".[utils]"
3734
pip install -e ".[all]"
3835
3936
- name: Smoke‑test pip install manify
@@ -45,31 +42,25 @@ jobs:
4542
4643
# Code quality checks
4744
- name: Check code formatting with Black
48-
run: |
49-
black --check manify/ --line-length 120
45+
run: black --check manify/ --line-length 120
5046
continue-on-error: true
5147

5248
- name: Check import ordering with isort
53-
run: |
54-
isort --check-only --profile black manify/
49+
run: isort --check-only --profile black manify/
5550
continue-on-error: true
5651

5752
- name: Run pylint
58-
run: |
59-
pylint manify/
53+
run: pylint manify/
6054
continue-on-error: true
6155

6256
# Type checking
6357
- name: Check type annotations with MyPy
64-
run: |
65-
# only check core modules—skip untyped imports
66-
mypy --ignore-missing-imports --no-warn-return-any --follow-imports=skip manify/
58+
run: mypy manify/
6759
continue-on-error: true
6860

6961
# Unit testing
7062
- name: Run unit tests & collect coverage
71-
run: |
72-
pytest tests --cov=manify --cov-report=xml:coverage.xml
63+
run: pytest tests --cov=manify --cov-report=xml:coverage.xml
7364

7465
# Code coverage
7566
- name: Upload coverage to Codecov

manify/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Manify: A Python Library for Learning Non-Euclidean Representations."""
2+
13
import manify.curvature_estimation
24
import manify.embedders
35
import manify.manifolds

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5+
from typing import Tuple
6+
57
import torch
68
from jaxtyping import Float
7-
from typing import Tuple
89

910

1011
def sampled_delta_hyperbolicity(

manify/curvature_estimation/greedy_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Tuple, Any
5+
from typing import Any, Tuple
66

77
import torch
88

manify/curvature_estimation/sectional_curvature.py

Lines changed: 71 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,76 +2,81 @@
22

33
import random
44

5+
import networkx as nx
56
import numpy as np
67

78
# The next couple functions are taken from this repo:
89
# https://github.com/HazyResearch/hyperbolics
910
# Paper: https://openreview.net/pdf?id=HJxeWnCcF7
1011

1112

12-
def Ka(D, m, b, c, a):
13-
if a == m:
14-
return 0.0
15-
k = D[a][m] ** 2 + D[b][c] ** 2 / 4.0 - (D[a][b] ** 2 + D[a][c] ** 2) / 2.0
16-
k /= 2 * D[a][m]
17-
return k
18-
19-
20-
def K(D, n, m, b, c):
21-
ks = [Ka(D, m, b, c, a) for a in range(n)]
22-
return np.mean(ks)
23-
24-
25-
def ref(D, size, n, m, b, c):
26-
ks = []
27-
for i in range(n):
28-
a = random.randint(0, size - 1)
29-
if a == b or a == c:
30-
continue
31-
else:
32-
ks.append(Ka(D, m, b, c, a))
33-
return np.mean(ks)
34-
35-
36-
def estimate_curvature(G, D, n):
37-
for m in range(n):
38-
ks = []
39-
edges = list(G.edges(m))
40-
for i in range(len(edges)):
41-
for j in range(b, len(edges)):
42-
b = edges[i]
43-
c = edges[j]
44-
ks.append(K(D, n, b, c))
45-
return None
46-
47-
48-
def sample(D, size, n_samples=100):
49-
samples = []
50-
_cnt = 0
51-
while _cnt < n_samples:
52-
a, b, c, m = random.sample(range(0, size), 4)
53-
k = Ka(D, m, b, c, a)
54-
samples.append(k)
55-
56-
_cnt += 1
57-
58-
return np.array(samples)
59-
60-
61-
def estimate(D, size, n_samples):
62-
samples = sample(D, size, n_samples)
63-
m1 = np.mean(samples)
64-
m2 = np.mean(samples**2)
65-
return samples
66-
67-
68-
def estimate_diff(D, size, n_sample, num):
69-
samples = []
70-
_cnt = 0
71-
while _cnt < n_sample:
72-
b, c, m = random.sample(range(0, size), 3)
73-
k = ref(D, size, num, m, b, c)
74-
# k=K(D, n, m, b, c)
75-
samples.append(k)
76-
_cnt += 1
77-
return np.array(samples)
13+
def sectional_curvature(G: nx.Graph) -> np.ndarray:
14+
raise NotImplementedError
15+
16+
17+
# def Ka(D: np.ndarray, m: int, b: int, c: int, a: int) -> float:
18+
# if a == m:
19+
# return 0.0
20+
# k = D[a][m] ** 2 + D[b][c] ** 2 / 4.0 - (D[a][b] ** 2 + D[a][c] ** 2) / 2.0
21+
# k /= 2 * D[a][m]
22+
# return float(k)
23+
24+
25+
# def K(D: np.ndarray, n: int, m: int, b: int, c: int) -> float:
26+
# ks = [Ka(D, m, b, c, a) for a in range(n)]
27+
# return float(np.mean(ks))
28+
29+
30+
# def ref(D: np.ndarray, size: int, n: int, m: int, b: int, c: int) -> float:
31+
# ks = []
32+
# for i in range(n):
33+
# a = random.randint(0, size - 1)
34+
# if a == b or a == c:
35+
# continue
36+
# else:
37+
# ks.append(Ka(D, m, b, c, a))
38+
# return float(np.mean(ks))
39+
40+
41+
# def estimate_curvature(G: nx.Graph, D: np.ndarray, n: int) -> None:
42+
# for m in range(n):
43+
# ks = []
44+
# edges = list(G.edges(m))
45+
# for i in range(len(edges)):
46+
# for j in range(i + 1, len(edges)):
47+
# b = edges[i]
48+
# c = edges[j]
49+
# ks.append(K(D, n, m, b, c))
50+
# return None
51+
52+
53+
# def sample(D: np.ndarray, size: int, n_samples: int = 100) -> np.ndarray:
54+
# samples = []
55+
# _cnt = 0
56+
# while _cnt < n_samples:
57+
# a, b, c, m = random.sample(range(0, size), 4)
58+
# k = Ka(D, m, b, c, a)
59+
# samples.append(k)
60+
61+
# _cnt += 1
62+
63+
# return np.array(samples)
64+
65+
66+
# def estimate(D: np.ndarray, size: int, n_samples: int) -> np.ndarray:
67+
# samples = sample(D, size, n_samples)
68+
# m1 = np.mean(samples)
69+
# m2 = np.mean(samples**2)
70+
# return samples
71+
72+
73+
# def estimate_diff(D: np.ndarray, size: int, n_sample: int, num: int) -> np.ndarray:
74+
# samples = []
75+
# _cnt = 0
76+
# while _cnt < n_sample:
77+
# b, c, m = random.sample(range(0, size), 3)
78+
# k = ref(D, size, num, m, b, c)
79+
# # k=K(D, n, m, b, c)
80+
# samples.append(k)
81+
# _cnt += 1
82+
# return np.array(samples)

manify/embedders/coordinate_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from jaxtyping import Float, Int
1212

1313
from ..manifolds import ProductManifold
14-
from .losses import d_avg, distortion_loss
14+
from ._losses import d_avg, distortion_loss
1515

1616
# TQDM: notebook or regular
1717
if "ipykernel" in sys.modules:

manify/manifolds.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
"""
2-
Tools for generating Riemannian manifolds and product manifolds.
1+
"""Tools for generating Riemannian manifolds and product manifolds.
32
4-
The module consists of two classes: Manifold and ProductManifold .The Manifold class
5-
represents hyperbolic, Euclidean, or spherical manifolds based on curvature.
6-
The ProductManifold class supports products of multiple manifolds,
7-
combining their geometric properties to create mixed-curvature. Both classes
8-
includes functions for different key geometric operations.
3+
The module consists of two classes: `Manifold` and `ProductManifold`. The `Manifold` class represents hyperbolic,
4+
Euclidean, or spherical manifolds of constant Gaussian curvature. The `ProductManifold` class supports Cartesian
5+
products of multiple manifolds, combining their geometric properties to create mixed-curvature. Both classes
6+
include methods for different key geometric operations, and are built on top of their corresponding `geoopt` classes
7+
(`Lorentz`, `Euclidean`, `Sphere`, `Scaled` and `ProductManifold`)
98
"""
109

1110
from __future__ import annotations
@@ -21,8 +20,7 @@
2120

2221

2322
class Manifold:
24-
"""
25-
Tools for generating Riemannian manifolds.
23+
"""Tools for generating Riemannian manifolds.
2624
2725
Parameters
2826
----------

0 commit comments

Comments
 (0)