We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 773bb0e commit 9fcc31dCopy full SHA for 9fcc31d
manify/utils/benchmarks.py
@@ -25,12 +25,7 @@
25
from ..predictors.perceptron import ProductSpacePerceptron
26
from ..predictors.svm import ProductSpaceSVM
27
28
-# from ..predictors.decision_tree import ProductSpaceDT, ProductSpaceRF
29
-from ..predictors.tree_icml import (
30
- ProductSpaceDT,
31
- ProductSpaceRF,
32
- SingleManifoldEnsembleRF,
33
-)
+from ..predictors.decision_tree import ProductSpaceDT, ProductSpaceRF
34
35
36
def _score(
@@ -254,7 +249,7 @@ def benchmark(
254
249
if adj is not None:
255
250
A_hat = get_A_hat(adj).detach()
256
251
else:
257
- dists = pdists ** 2
252
+ dists = pdists**2
258
253
dists_train = dists[train_idx][:, train_idx]
259
dists /= dists_train[torch.isfinite(dists_train)].max()
260
A_hat = get_A_hat(dists).detach()
0 commit comments