Skip to content

Commit 9b8c0c6

Browse files
committed
fix style and docstrings
1 parent 762dbd5 commit 9b8c0c6

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

nanotabpfn/evaluation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from openml.config import set_root_cache_directory
77
from openml.tasks import TaskType
8+
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score
89
from sklearn.preprocessing import LabelEncoder
910

1011
from nanotabpfn.interface import NanoTabPFNRegressor, NanoTabPFNClassifier
@@ -154,11 +155,9 @@ def get_openml_predictions(
154155

155156
for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
156157
if args.model_type == "classification":
157-
from sklearn.metrics import roc_auc_score, balanced_accuracy_score
158158
acc = balanced_accuracy_score(y_true, y_pred)
159159
auc = roc_auc_score(y_true, y_proba, multi_class='ovr')
160160
print(f"Dataset: {dataset_name} | ROC AUC: {auc:.4f} | Balanced Accuracy: {acc:.4f}")
161161
else:
162-
from sklearn.metrics import r2_score
163162
r2 = r2_score(y_true, y_pred)
164163
print(f"Dataset: {dataset_name} | R2: {r2:.4f}")

nanotabpfn/interface.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def init_model_from_state_dict_file(file_path):
3838

3939
def get_feature_preprocessor(X: ndarray | pd.DataFrame) -> ColumnTransformer:
4040
"""
41-
fits a preprocessor that replaces NaNs with the mean of the respective column
42-
and scales each column to mean 0 and variance 1
41+
fits a preprocessor that imputes NaNs
4342
"""
4443
X = pd.DataFrame(X)
4544
num_mask = []

0 commit comments

Comments
 (0)