-
Notifications
You must be signed in to change notification settings - Fork 16
Open
Description
It would be useful to also see the performance of each SSL model against the purely supervised backbone run on the labeled data.
For example, TSVM vs pure SVM:
# Benchmark: a semi-supervised TSVM against a purely supervised SVM that is
# trained on the labeled subset only. Both are scored on the samples the
# dataset object exposes as "unlabeled" (their true labels are kept for
# evaluation).
import numpy as np
from LAMDA_SSL.Dataset.Tabular.BreastCancer import BreastCancer
from sklearn import preprocessing
from LAMDA_SSL.Algorithm.Classification.TSVM import TSVM
from LAMDA_SSL.Evaluation.Classifier.Accuracy import Accuracy
from sklearn import svm

# 10% of the samples are requested as labeled; stratification and shuffling
# are both switched on.
data = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
X_lab, y_lab = data.labeled_X, data.labeled_y
X_unl, y_unl = data.unlabeled_X, data.unlabeled_y

# The scaler is fitted on the labeled and unlabeled features stacked
# together, then applied to each part separately.
scaler = preprocessing.StandardScaler().fit(np.vstack([X_lab, X_unl]))
X_lab = scaler.transform(X_lab)
X_unl = scaler.transform(X_unl)

# --- Semi-supervised model -------------------------------------------------
# A range of Cl and Cu values was tried, starting from Cl=15 and Cu=0.0001
# and then gradually raising Cu and lowering Cl. It did not seem to make a
# difference to the score.
tsvm = TSVM(Cl=1, Cu=1, kernel="linear")
tsvm.fit(X=X_lab, y=y_lab, unlabeled_X=X_unl)
# predict() is called without arguments; presumably it returns predictions
# for the unlabeled samples given to fit() -- confirm against LAMDA_SSL docs.
tsvm_pred = tsvm.predict()
tsvm_acc = Accuracy().scoring(y_unl, tsvm_pred)
print(f"SSL TSVM score: {tsvm_acc}")
#> SSL TSVM score: 0.9609375

# --- Supervised baseline ---------------------------------------------------
# Compare with a pure SVM fitted on the labeled data only.
# NOTE(review): SVC() is left at its default arguments while the TSVM above
# is built with kernel="linear" -- confirm this is the intended baseline.
baseline = svm.SVC()
baseline.fit(X_lab, y_lab)
baseline_acc = Accuracy().scoring(y_unl, baseline.predict(X_unl))
print(f"SL SVM score: {baseline_acc}")
#> SL SVM score: 0.955078125
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels