Skip to content

Commit 73fa974

Browse files
committed
Added test cases
1 parent b8c4f75 commit 73fa974

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_curvature_estimation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import torch
22

3-
from manify.manifolds import ProductManifold
3+
from manify.curvature_estimation._pipelines import (
4+
distortion_pipeline,
5+
predictor_pipeline,
6+
)
47
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity
58
from manify.curvature_estimation.sectional_curvature import sectional_curvature
6-
from manify.curvature_estimation._pipelines import distortion_pipeline, predictor_pipeline
7-
from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, vectorized_delta_hyperbolicity
89
from manify.curvature_estimation.greedy_method import greedy_signature_selection
10+
from manify.manifolds import ProductManifold
11+
from manify.utils.dataloaders import load_hf
912

1013

1114
def test_delta_hyperbolicity():
1215
torch.manual_seed(42)
1316
pm = ProductManifold(signature=[(-1.0, 2)])
14-
X, _ = pm.sample(z_mean=torch.stack([pm.mu0] * 10))
17+
X = pm.sample(z_mean=torch.stack([pm.mu0] * 10))
1518
dists = pm.pdist(X)
1619

1720
# Test sampled method

0 commit comments

Comments
 (0)