Skip to content

Commit ce9b5bc

Browse files
committed
add cluster
1 parent 4235942 commit ce9b5bc

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

dance/modules/base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,34 @@ class BaseRegressionMethod(BaseMethod):
166166
class BaseClusteringMethod(BaseMethod):
    """Base class for clustering methods.

    Scoring falls back to the metric named by ``_DEFAULT_METRIC`` when the
    caller does not supply ``score_func``.

    """

    _DEFAULT_METRIC = "ari"

    def score(
        self,
        x,
        y,
        *,
        score_func: Optional[Union[str, Mapping[Any, float]]] = None,
        return_pred: bool = False,
        valid_idx=None,
        test_idx=None,
    ) -> Union[float, Mapping[str, float], Tuple[Union[float, Mapping[str, float]], Any]]:
        """Predict on ``x`` and score the prediction against ``y``.

        Parameters
        ----------
        x
            Input passed unchanged to ``self.predict``.
        y
            Reference labels, compared against the prediction.
        score_func
            Metric specification handed to ``resolve_score_func``. When falsy,
            ``self._DEFAULT_METRIC`` is used instead.
        return_pred
            If ``True``, also return the prediction.
        valid_idx
            Indices of the validation samples. When ``None``, a single score
            over all samples is returned and ``test_idx`` is ignored.
        test_idx
            Indices of the test samples. Required whenever ``valid_idx`` is
            given.

        Returns
        -------
        A single score when ``valid_idx`` is ``None``; otherwise a dict with the
        keys ``"valid_score"`` and ``"test_score"``. When ``return_pred`` is
        ``True``, a ``(result, y_pred)`` tuple is returned instead.

        Raises
        ------
        ValueError
            If ``valid_idx`` is given without ``test_idx``.

        """
        # Fail fast, before running the prediction, instead of crashing later
        # with "'NoneType' object is not iterable" while iterating ``test_idx``.
        if valid_idx is not None and test_idx is None:
            raise ValueError("test_idx must be provided when valid_idx is given")

        y_pred = self.predict(x)
        func = resolve_score_func(score_func or self._DEFAULT_METRIC)

        def subset_score(indices):
            # Element-wise indexing only requires ``y`` / ``y_pred`` to support
            # ``obj[i]``; no assumption is made about their container type.
            return func([y[i] for i in indices], [y_pred[i] for i in indices])

        if valid_idx is None:
            result = func(y, y_pred)
        else:
            result = {
                "valid_score": subset_score(valid_idx),
                "test_score": subset_score(test_idx),
            }
        return (result, y_pred) if return_pred else result

    def fit_score(
        self,
        x,
        y,
        *,
        score_func: Optional[Union[str, Mapping[Any, float]]] = None,
        return_pred: bool = False,
        valid_idx=None,
        test_idx=None,
        **fit_kwargs,
    ) -> Union[float, Mapping[str, float], Tuple[Union[float, Mapping[str, float]], Any]]:
        """Shortcut for fitting data using the input feature and return eval.

        ``self.fit`` is called with ``x`` and ``fit_kwargs`` only; the result of
        ``self.score`` with the remaining arguments is then returned.

        Note
        ----
        Only works for models where the fitting does not require labeled data, i.e. unsupervised methods.

        """
        self.fit(x, **fit_kwargs)
        return self.score(x, y, score_func=score_func, return_pred=return_pred, valid_idx=valid_idx, test_idx=test_idx)

examples/tuning/cluster_graphsc/main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import os
33
import pprint
44
import sys
5+
from cgi import test
56
from pathlib import Path
67

78
import numpy as np
89
import torch
9-
import wandb
10+
from sklearn.model_selection import train_test_split
1011

12+
import wandb
1113
from dance import logger
1214
from dance.datasets.singlemodality import ClusteringDataset
1315
from dance.modules.single_modality.clustering.graphsc import GraphSC
@@ -74,7 +76,8 @@ def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer)
7476
preprocessing_pipeline = pipeline_planer.generate(**kwargs)
7577
print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
7678
preprocessing_pipeline(data)
77-
79+
total_idx = range(data.shape[0])
80+
valid_idx, test_idx = train_test_split(total_idx, test_size=0.9, random_state=args.seed)
7881
graph, y = data.get_train_data()
7982
n_clusters = len(np.unique(y))
8083

@@ -91,8 +94,8 @@ def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer)
9194
num_workers=args.num_workers, device=args.device)
9295
model.fit(graph, epochs=args.epochs, lr=args.learning_rate, show_epoch_ari=args.show_epoch_ari,
9396
eval_epoch=args.eval_epoch)
94-
score = model.score(None, y)
95-
wandb.log({"acc": score})
97+
valid_score, test_score = model.score(graph, y, valid_idx=valid_idx, test_idx=test_idx)
98+
wandb.log({"ari": valid_score, "test_ari": test_score})
9699
wandb.finish()
97100
del model
98101
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)