From c6906f6f82fe8a998f916aacd6022af674aea9d5 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 11:39:57 +0200
Subject: [PATCH 1/7] copy feature inputs before training the final model
---
drevalpy/experiment.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py
index acdfcd12..e9f1e55f 100644
--- a/drevalpy/experiment.py
+++ b/drevalpy/experiment.py
@@ -1358,10 +1358,11 @@ def train_final_model(
train_dataset.shuffle(random_state=42)
model.build_model(hyperparameters=best_hpams)
+ drug_features = drug_features.copy() if drug_features is not None else None
model.train(
output=train_dataset,
output_earlystopping=early_stopping_dataset,
- cell_line_input=cl_features,
+ cell_line_input=cl_features.copy(),
drug_input=drug_features,
model_checkpoint_dir=model_checkpoint_dir,
)
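
The copies added here guard against models that modify their feature input in place during training, which would silently leak those changes into later uses of the same feature data. A minimal standalone sketch of the hazard (the dict layout and trainer below are illustrative stand-ins, not the drevalpy API):

    import numpy as np

    # Stand-in feature store: one cell line, one view.
    features = {"CVCL_0023": {"gene_expression": np.array([1.0, 2.0, 3.0])}}

    def careless_train(cell_line_input: dict) -> None:
        # A model that rescales its input in place corrupts the caller's shared features.
        for views in cell_line_input.values():
            views["gene_expression"] /= views["gene_expression"].max()

    careless_train(features)
    print(features["CVCL_0023"]["gene_expression"])  # already rescaled: [0.333... 0.666... 1.]

Passing cl_features.copy() and a copied drug_features keeps the experiment's shared feature datasets untouched (assuming FeatureDataset.copy() copies deeply enough for the model in question).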
From 3f12cbb7be3c9c73e670b4d079ed7ea643f7dc28 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 11:43:51 +0200
Subject: [PATCH 2/7] pass final_model_path into train_final_model instead of result_path
---
drevalpy/experiment.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py
index e9f1e55f..e61ab973 100644
--- a/drevalpy/experiment.py
+++ b/drevalpy/experiment.py
@@ -302,7 +302,7 @@ def drug_response_experiment(
path_data=path_data,
model_checkpoint_dir=model_checkpoint_dir,
metric=metric,
- result_path=final_model_path,
+ final_model_path=final_model_path,
test_mode=test_mode,
val_ratio=0.1,
hyperparameter_tuning=hyperparameter_tuning,
@@ -1291,7 +1291,7 @@ def train_final_model(
path_data: str,
model_checkpoint_dir: str,
metric: str,
- result_path: str,
+ final_model_path: str,
test_mode: str = "LCO",
val_ratio: float = 0.1,
hyperparameter_tuning: bool = True,
@@ -1314,7 +1314,7 @@ def train_final_model(
:param path_data: path to data directory
:param model_checkpoint_dir: checkpoint dir for intermediate tuning models
:param metric: metric for tuning, e.g., "RMSE"
- :param result_path: path to results
+ :param final_model_path: path to final_model save directory
:param test_mode: split logic for validation (LCO, LDO, LTO, LPO)
:param val_ratio: validation size ratio
:param hyperparameter_tuning: whether to perform hyperparameter tuning
@@ -1367,7 +1367,6 @@ def train_final_model(
model_checkpoint_dir=model_checkpoint_dir,
)
- final_model_path = os.path.join(result_path, "final_model")
os.makedirs(final_model_path, exist_ok=True)
model.save(final_model_path)
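
After the rename, the caller decides where the final model is stored; train_final_model only creates that directory and saves into it. A hedged sketch of a call site (the result directory layout below is an assumption, not taken from the patch):

    import os

    result_path = "results/my_run"                               # hypothetical results directory
    final_model_path = os.path.join(result_path, "final_model")  # caller-owned location
    os.makedirs(final_model_path, exist_ok=True)                 # train_final_model also does this before model.save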
From 6d7ced171a7c06a70a3845d7bc76e4468e3b1e90 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 12:01:56 +0200
Subject: [PATCH 3/7] save meta info in to_csv and extract it in from_csv; adapt
 the corresponding tests
---
drevalpy/datasets/dataset.py | 38 ++++++++-----
tests/test_dataset.py | 103 ++++++++++++++++++++---------------
2 files changed, 85 insertions(+), 56 deletions(-)
diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py
index e8ddbf2f..3150eee9 100644
--- a/drevalpy/datasets/dataset.py
+++ b/drevalpy/datasets/dataset.py
@@ -817,6 +817,7 @@ def from_csv(
view_name: str,
drop_columns: list[str] | None = None,
transpose: bool = False,
+ extract_meta_info: bool = True,
):
"""Load a one-view feature dataset from a csv file.
@@ -830,6 +831,7 @@ def from_csv(
:param id_column: name of the column containing the identifiers
:param drop_columns: list of columns to drop (e.g. other identifier columns)
:param transpose: if True, the csv is transposed, i.e. the rows become columns and vice versa
+ :param extract_meta_info: if True, extracts meta information from the dataset, e.g. gene names for gene expression
:returns: FeatureDataset object containing data from provided csv file.
"""
data = pd.read_csv(path_to_csv).T if transpose else pd.read_csv(path_to_csv)
@@ -837,7 +839,6 @@ def from_csv(
ids = data[id_column].values
data_features = data.drop(columns=(drop_columns or []))
data_features = data_features.set_index(id_column)
- # remove duplicate feature rows (rows with the same index)
data_features = data_features[~data_features.index.duplicated(keep="first")]
features = {}
@@ -845,29 +846,40 @@ def from_csv(
features_for_instance = data_features.loc[identifier].values
features[identifier] = {view_name: features_for_instance}
- return cls(features=features)
+ meta_info = None
+ if extract_meta_info:
+ meta_info = {view_name: list(data_features.columns)}
+
+ return cls(features=features, meta_info=meta_info)
def to_csv(self, path: str | Path, id_column: str, view_name: str):
"""
- Save the feature dataset to a CSV file.
+ Save the feature dataset to a CSV file. If meta_info is available for the view and valid,
+ it will be written as column names.
:param path: Path to the CSV file.
:param id_column: Name of the column containing the identifiers.
- :param view_name: Name of the view (e.g., gene_expression).
-
- :raises ValueError: If the view is not found for an identifier.
+ :param view_name: Name of the view.
+ :raises ValueError: If the view is not found for an identifier.
"""
data = []
+ feature_names = None
+
for identifier, feature_dict in self.features.items():
- # Get the feature vector for the specified view
- if view_name in feature_dict:
- row = {id_column: identifier}
- row.update({f"feature_{i}": value for i, value in enumerate(feature_dict[view_name])})
- data.append(row)
- else:
+ vector = feature_dict.get(view_name)
+ if vector is None:
raise ValueError(f"View {view_name!r} not found for identifier {identifier!r}.")
- # Convert to DataFrame and save to CSV
+ if feature_names is None:
+ meta_names = self.meta_info.get(view_name)
+ if isinstance(meta_names, list) and len(meta_names) == len(vector):
+ feature_names = meta_names
+ else:
+ feature_names = [f"feature_{i}" for i in range(len(vector))]
+
+ row = {id_column: identifier}
+ row.update({name: value for name, value in zip(feature_names, vector)})
+ data.append(row)
+
df = pd.DataFrame(data)
df.to_csv(path, index=False)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 5c0181d6..0162bfb9 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -591,59 +591,76 @@ def test_add_features(sample_dataset: FeatureDataset, graph_dataset: FeatureData
assert "molecular_graph" in sample_dataset.view_names
-def test_feature_dataset_csv_methods():
- """Test the `from_csv` and `to_csv` methods of the FeatureDataset class."""
- # Create temporary directory for testing
+def test_feature_dataset_csv_meta_handling():
+ """Test `from_csv` and `to_csv` methods with and without meta_info handling."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
- # Create test data
- test_csv_path = temp_dir / "test_features.csv"
- data = {
- "id": ["A", "B", "C"],
- "feature_1": [1.0, 2.0, 3.0],
- "feature_2": [4.0, 5.0, 6.0],
- }
- df = pd.DataFrame(data)
- df.to_csv(test_csv_path, index=False)
-
- # Test `from_csv` method
+ # ------------------------------------
+ # 0. Create initial test DataFrame/CSV
+ # ------------------------------------
+ df_with_named_cols = pd.DataFrame(
+ {
+ "id": ["A", "B", "C"],
+ "feature_1": [1.0, 2.0, 3.0],
+ "feature_2": [4.0, 5.0, 6.0],
+ }
+ )
+ csv_with_meta = temp_dir / "input_with_meta.csv"
+ df_with_named_cols.to_csv(csv_with_meta, index=False)
+
view_name = "example_view"
- feature_dataset = FeatureDataset.from_csv(
- path_to_csv=test_csv_path, id_column="id", view_name=view_name, drop_columns=None
+
+ # ------------------------------------
+ # 1. Load from CSV → should extract meta_info
+ # ------------------------------------
+ dataset = FeatureDataset.from_csv(
+ path_to_csv=csv_with_meta,
+ id_column="id",
+ view_name=view_name,
)
- # Validate loaded data
- assert set(feature_dataset.identifiers) == {"A", "B", "C"}, "Identifiers mismatch."
- assert feature_dataset.view_names == [view_name], "View names mismatch."
- expected_features = {
- "A": {"example_view": np.array([1.0, 4.0])},
- "B": {"example_view": np.array([2.0, 5.0])},
- "C": {"example_view": np.array([3.0, 6.0])},
- }
- for identifier in expected_features:
- np.testing.assert_array_equal(
- feature_dataset.features[identifier][view_name],
- expected_features[identifier][view_name],
- f"Feature mismatch for identifier {identifier}.",
- )
-
- # Test `to_csv` method
- output_csv_path = temp_dir / "output_features.csv"
- feature_dataset.to_csv(path=output_csv_path, id_column="id", view_name=view_name)
-
- # Validate saved data
- saved_df = pd.read_csv(output_csv_path)
- expected_saved_df = pd.DataFrame(
+ assert dataset.meta_info == {view_name: ["feature_1", "feature_2"]}
+ assert set(dataset.identifiers) == {"A", "B", "C"}
+ assert dataset.view_names == [view_name]
+
+ # ------------------------------------
+ # 2. Save with meta_info → column names should be preserved
+ # ------------------------------------
+ csv_out_with_meta = temp_dir / "saved_with_meta.csv"
+ dataset.to_csv(csv_out_with_meta, id_column="id", view_name=view_name)
+
+ saved_df = pd.read_csv(csv_out_with_meta)
+ pd.testing.assert_frame_equal(saved_df, df_with_named_cols, check_dtype=False)
+
+ # ------------------------------------
+ # 3. Save without meta_info → fallback to generic feature_0, feature_1
+ # ------------------------------------
+ dataset._meta_info = {} # simulate no meta info
+ csv_out_no_meta = temp_dir / "saved_no_meta.csv"
+ dataset.to_csv(csv_out_no_meta, id_column="id", view_name=view_name)
+
+ df_fallback = pd.DataFrame(
{
"id": ["A", "B", "C"],
"feature_0": [1.0, 2.0, 3.0],
"feature_1": [4.0, 5.0, 6.0],
}
)
- pd.testing.assert_frame_equal(
- saved_df,
- expected_saved_df,
- check_dtype=False, # Relax dtype check for cross-platform compatibility
- obj="Saved CSV data",
+ saved_fallback_df = pd.read_csv(csv_out_no_meta)
+ pd.testing.assert_frame_equal(saved_fallback_df, df_fallback, check_dtype=False)
+
+ # ------------------------------------
+ # 4. Load fallback CSV → should reconstruct generic meta_info
+ # ------------------------------------
+ dataset_fallback = FeatureDataset.from_csv(
+ path_to_csv=csv_out_no_meta,
+ id_column="id",
+ view_name=view_name,
+ )
+
+ assert dataset_fallback.meta_info == {view_name: ["feature_0", "feature_1"]}
+ np.testing.assert_array_equal(
+ dataset_fallback.features["B"][view_name],
+ np.array([2.0, 5.0]),
)
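
Taken together, from_csv now records the original column names as meta_info and to_csv writes them back out, falling back to feature_0, feature_1, ... when no usable meta info is present. A compact round trip of the named-column case (assumes drevalpy is installed; file names and gene names are made up, and two small CSVs are written to the working directory):

    import pandas as pd
    from drevalpy.datasets.dataset import FeatureDataset

    pd.DataFrame({"id": ["A", "B"], "TP53": [1.0, 2.0], "EGFR": [4.0, 5.0]}).to_csv("ge.csv", index=False)

    ds = FeatureDataset.from_csv(path_to_csv="ge.csv", id_column="id", view_name="gene_expression")
    assert ds.meta_info == {"gene_expression": ["TP53", "EGFR"]}

    ds.to_csv("ge_roundtrip.csv", id_column="id", view_name="gene_expression")
    assert list(pd.read_csv("ge_roundtrip.csv").columns) == ["id", "TP53", "EGFR"]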
From a4ea7ac149e00e57ed1b7c3e81c949159577442e Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 12:06:54 +0200
Subject: [PATCH 4/7] train the final model in the run suite test for more coverage
---
tests/test_run_suite.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py
index 8583d80f..6b1d7524 100644
--- a/tests/test_run_suite.py
+++ b/tests/test_run_suite.py
@@ -42,7 +42,7 @@
"multiprocessing": False,
"path_data": "../data",
"model_checkpoint_dir": "TEMPORARY",
- "final_model_on_full_data": False,
+ "final_model_on_full_data": True,
}
],
)
From 15c6fe7a0f4d0d0a0b672fbecd2c735228d0c142 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 12:11:02 +0200
Subject: [PATCH 5/7] default meta_info to an empty dict in from_csv
---
drevalpy/datasets/dataset.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py
index 3150eee9..e306370a 100644
--- a/drevalpy/datasets/dataset.py
+++ b/drevalpy/datasets/dataset.py
@@ -846,7 +846,7 @@ def from_csv(
features_for_instance = data_features.loc[identifier].values
features[identifier] = {view_name: features_for_instance}
- meta_info = None
+ meta_info = {}
if extract_meta_info:
meta_info = {view_name: list(data_features.columns)}
From a061d15e7ef2778c1b960bfe863ab9183901c0ea Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Thu, 19 Jun 2025 20:26:55 +0200
Subject: [PATCH 6/7] added response transformation to the run suite test and
 better None typing
---
drevalpy/utils.py | 4 ++--
tests/test_run_suite.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/drevalpy/utils.py b/drevalpy/utils.py
index c21242f0..4695b44f 100644
--- a/drevalpy/utils.py
+++ b/drevalpy/utils.py
@@ -392,7 +392,7 @@ def get_datasets(
@pipeline_function
-def get_response_transformation(response_transformation: str) -> TransformerMixin | None:
+def get_response_transformation(response_transformation: str | None) -> TransformerMixin | None:
"""
Get the sklearn response transformation object of choice.
@@ -401,7 +401,7 @@ def get_response_transformation(response_transformation: str) -> TransformerMixi
:returns: response transformation object
:raises ValueError: if the response transformation is not recognized
"""
- if response_transformation == "None":
+ if (response_transformation == "None") or (response_transformation is None):
return None
if response_transformation == "standard":
return StandardScaler()
diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py
index 6b1d7524..eca8ad8d 100644
--- a/tests/test_run_suite.py
+++ b/tests/test_run_suite.py
@@ -38,7 +38,7 @@
"overwrite": False,
"optim_metric": "RMSE",
"n_cv_splits": 2,
- "response_transformation": "None",
+ "response_transformation": "standard",
"multiprocessing": False,
"path_data": "../data",
"model_checkpoint_dir": "TEMPORARY",
From d2c6037b00dec6ffe7c0183c6ea9f4bad627aef6 Mon Sep 17 00:00:00 2001
From: Judith Bernett
Date: Fri, 20 Jun 2025 17:26:46 +0200
Subject: [PATCH 7/7] closes #236, #237: train_final_model is no longer a
 pipeline function. Added inverse transform in train_and_predict, fixed the
 wrong response transformation in cross-study, added response_transformation +
 inverse transform in train_final_model
---
drevalpy/experiment.py | 21 ++++++++++++++++-----
drevalpy/visualization/utils.py | 11 ++++++-----
tests/test_run_suite.py | 2 +-
3 files changed, 23 insertions(+), 11 deletions(-)
diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py
index e61ab973..8c83fc98 100644
--- a/drevalpy/experiment.py
+++ b/drevalpy/experiment.py
@@ -585,7 +585,7 @@ def cross_study_prediction(
drug_input=drug_input,
)
if response_transformation:
- dataset._response = response_transformation.inverse_transform(dataset.response)
+ dataset.inverse_transform(response_transformation)
else:
dataset._predictions = np.array([])
dataset.to_csv(
@@ -1000,11 +1000,15 @@ def train_and_predict(
drug_input=drug_input,
)
- if response_transformation:
- prediction_dataset.inverse_transform(response_transformation)
else:
prediction_dataset._predictions = np.array([])
+ if response_transformation:
+ train_dataset.inverse_transform(response_transformation)
+ prediction_dataset.inverse_transform(response_transformation)
+ if early_stopping_dataset is not None:
+ early_stopping_dataset.inverse_transform(response_transformation)
+
return prediction_dataset
@@ -1016,7 +1020,7 @@ def train_and_evaluate(
validation_dataset: DrugResponseDataset,
early_stopping_dataset: DrugResponseDataset | None = None,
response_transformation: TransformerMixin | None = None,
- metric: str = "rmse",
+ metric: str = "RMSE",
model_checkpoint_dir: str = "TEMPORARY",
) -> dict[str, float]:
"""
@@ -1283,7 +1287,6 @@ def generate_data_saving_path(model_name, drug_id, result_path, suffix) -> str:
return model_path
-@pipeline_function
def train_final_model(
model_class: type[DRPModel],
full_dataset: DrugResponseDataset,
@@ -1356,6 +1359,10 @@ def train_final_model(
print(f"Best hyperparameters for final model: {best_hpams}")
train_dataset.add_rows(validation_dataset)
train_dataset.shuffle(random_state=42)
+ if response_transformation:
+ train_dataset.fit_transform(response_transformation)
+ if early_stopping_dataset is not None:
+ early_stopping_dataset.transform(response_transformation)
model.build_model(hyperparameters=best_hpams)
drug_features = drug_features.copy() if drug_features is not None else None
@@ -1366,6 +1373,10 @@ def train_final_model(
drug_input=drug_features,
model_checkpoint_dir=model_checkpoint_dir,
)
+ if response_transformation:
+ train_dataset.inverse_transform(response_transformation)
+ if early_stopping_dataset is not None:
+ early_stopping_dataset.inverse_transform(response_transformation)
os.makedirs(final_model_path, exist_ok=True)
model.save(final_model_path)
diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py
index 428c108f..58e78961 100644
--- a/drevalpy/visualization/utils.py
+++ b/drevalpy/visualization/utils.py
@@ -12,6 +12,7 @@
from ..datasets.dataset import DrugResponseDataset
from ..evaluation import AVAILABLE_METRICS, evaluate
+from ..models.utils import CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER
from ..pipeline_function import pipeline_function
from . import (
ComparisonScatter,
@@ -228,7 +229,7 @@ def prep_results(
elif file == "cell_line_names.csv":
cell_line_names = pd.read_csv(os.path.join(root, file), index_col=0)
# index: cellosaurus_id, column: cell_line_name
- cell_line_metadata.update(zip(cell_line_names["cell_line_name"], cell_line_names.index))
+ cell_line_metadata.update(zip(cell_line_names[CELL_LINE_IDENTIFIER], cell_line_names.index))
# add variables
# split the index by "_" into: algorithm, randomization, test_mode, split, CV_split
@@ -251,7 +252,7 @@ def prep_results(
all_drugs = [drug_metadata[drug] for drug in eval_results_per_drug["drug"]]
eval_results_per_drug["drug_name"] = all_drugs
# rename drug to pubchem_id
- eval_results_per_drug = eval_results_per_drug.rename(columns={"drug": "pubchem_id"})
+ eval_results_per_drug = eval_results_per_drug.rename(columns={"drug": DRUG_IDENTIFIER})
if eval_results_per_cell_line is not None:
print("Reformatting the evaluation results per cell line ...")
eval_results_per_cell_line[["algorithm", "rand_setting", "test_mode", "split", "CV_split"]] = (
@@ -259,7 +260,7 @@ def prep_results(
)
all_cello_ids = [cell_line_metadata[cell_line] for cell_line in eval_results_per_cell_line["cell_line"]]
eval_results_per_cell_line["cellosaurus_id"] = all_cello_ids
- eval_results_per_cell_line = eval_results_per_cell_line.rename(columns={"cell_line": "cell_line_name"})
+ eval_results_per_cell_line = eval_results_per_cell_line.rename(columns={"cell_line": CELL_LINE_IDENTIFIER})
print("Reformatting the true vs. predicted values ...")
t_vs_p[["algorithm", "rand_setting", "test_mode", "split", "CV_split"]] = t_vs_p["model"].str.split(
@@ -270,8 +271,8 @@ def prep_results(
t_vs_p["drug_name"] = all_drugs
all_cello_ids = [cell_line_metadata[cell_line] for cell_line in t_vs_p["cell_line"]]
t_vs_p["cellosaurus_id"] = all_cello_ids
- t_vs_p = t_vs_p.rename(columns={"cell_line": "cell_line_name", "drug": "pubchem_id"})
- t_vs_p["pubchem_id"] = t_vs_p["pubchem_id"].astype(str)
+ t_vs_p = t_vs_p.rename(columns={"cell_line": CELL_LINE_IDENTIFIER, "drug": DRUG_IDENTIFIER})
+ t_vs_p[DRUG_IDENTIFIER] = t_vs_p[DRUG_IDENTIFIER].astype(str)
if "NaiveMeanEffectsPredictor" in eval_results["algorithm"].unique():
eval_results = _normalize_metrics_by_mean_effects(
diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py
index eca8ad8d..e6f2ff1b 100644
--- a/tests/test_run_suite.py
+++ b/tests/test_run_suite.py
@@ -42,6 +42,7 @@
"multiprocessing": False,
"path_data": "../data",
"model_checkpoint_dir": "TEMPORARY",
+ "no_hyperparameter_tuning": True,
"final_model_on_full_data": True,
}
],
@@ -55,7 +56,6 @@ def test_run_suite(args):
temp_dir = tempfile.TemporaryDirectory()
args["path_out"] = temp_dir.name
args = Namespace(**args)
- args.no_hyperparameter_tuning = True
get_parser()
check_arguments(args)
main(args)
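
The fit_transform / inverse_transform pairing added to train_final_model is the usual scikit-learn round trip: scale the responses before training, then undo the scaling so stored and evaluated values are back on the original scale. A minimal standalone sketch with StandardScaler on plain numpy arrays (not DrugResponseDataset):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    response = np.array([[0.2], [1.5], [3.1], [4.0]])  # stand-in response values
    scaler = StandardScaler()

    scaled = scaler.fit_transform(response)       # what the model is trained on
    restored = scaler.inverse_transform(scaled)   # back to the original response scale

    assert np.allclose(restored, response)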