diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py index e8ddbf2f..e306370a 100644 --- a/drevalpy/datasets/dataset.py +++ b/drevalpy/datasets/dataset.py @@ -817,6 +817,7 @@ def from_csv( view_name: str, drop_columns: list[str] | None = None, transpose: bool = False, + extract_meta_info: bool = True, ): """Load a one-view feature dataset from a csv file. @@ -830,6 +831,7 @@ def from_csv( :param id_column: name of the column containing the identifiers :param drop_columns: list of columns to drop (e.g. other identifier columns) :param transpose: if True, the csv is transposed, i.e. the rows become columns and vice versa + :param extract_meta_info: if True, extracts meta information from the dataset, e.g. gene names for gene expression :returns: FeatureDataset object containing data from provided csv file. """ data = pd.read_csv(path_to_csv).T if transpose else pd.read_csv(path_to_csv) @@ -837,7 +839,6 @@ def from_csv( ids = data[id_column].values data_features = data.drop(columns=(drop_columns or [])) data_features = data_features.set_index(id_column) - # remove duplicate feature rows (rows with the same index) data_features = data_features[~data_features.index.duplicated(keep="first")] features = {} @@ -845,29 +846,40 @@ def from_csv( features_for_instance = data_features.loc[identifier].values features[identifier] = {view_name: features_for_instance} - return cls(features=features) + meta_info = {} + if extract_meta_info: + meta_info = {view_name: list(data_features.columns)} + + return cls(features=features, meta_info=meta_info) def to_csv(self, path: str | Path, id_column: str, view_name: str): """ - Save the feature dataset to a CSV file. + Save the feature dataset to a CSV file. If meta_info is available for the view and valid, + it will be written as column names. :param path: Path to the CSV file. :param id_column: Name of the column containing the identifiers. - :param view_name: Name of the view (e.g., gene_expression). - - :raises ValueError: If the view is not found for an identifier. + :param view_name: Name of the view. """ data = [] + feature_names = None + for identifier, feature_dict in self.features.items(): - # Get the feature vector for the specified view - if view_name in feature_dict: - row = {id_column: identifier} - row.update({f"feature_{i}": value for i, value in enumerate(feature_dict[view_name])}) - data.append(row) - else: + vector = feature_dict.get(view_name) + if vector is None: raise ValueError(f"View {view_name!r} not found for identifier {identifier!r}.") - # Convert to DataFrame and save to CSV + if feature_names is None: + meta_names = self.meta_info.get(view_name) + if isinstance(meta_names, list) and len(meta_names) == len(vector): + feature_names = meta_names + else: + feature_names = [f"feature_{i}" for i in range(len(vector))] + + row = {id_column: identifier} + row.update({name: value for name, value in zip(feature_names, vector)}) + data.append(row) + df = pd.DataFrame(data) df.to_csv(path, index=False) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index acdfcd12..8c83fc98 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -302,7 +302,7 @@ def drug_response_experiment( path_data=path_data, model_checkpoint_dir=model_checkpoint_dir, metric=metric, - result_path=final_model_path, + final_model_path=final_model_path, test_mode=test_mode, val_ratio=0.1, hyperparameter_tuning=hyperparameter_tuning, @@ -585,7 +585,7 @@ def cross_study_prediction( drug_input=drug_input, ) if response_transformation: - dataset._response = response_transformation.inverse_transform(dataset.response) + dataset.inverse_transform(response_transformation) else: dataset._predictions = np.array([]) dataset.to_csv( @@ -1000,11 +1000,15 @@ def train_and_predict( drug_input=drug_input, ) - if response_transformation: - prediction_dataset.inverse_transform(response_transformation) else: prediction_dataset._predictions = np.array([]) + if response_transformation: + train_dataset.inverse_transform(response_transformation) + prediction_dataset.inverse_transform(response_transformation) + if early_stopping_dataset is not None: + early_stopping_dataset.inverse_transform(response_transformation) + return prediction_dataset @@ -1016,7 +1020,7 @@ def train_and_evaluate( validation_dataset: DrugResponseDataset, early_stopping_dataset: DrugResponseDataset | None = None, response_transformation: TransformerMixin | None = None, - metric: str = "rmse", + metric: str = "RMSE", model_checkpoint_dir: str = "TEMPORARY", ) -> dict[str, float]: """ @@ -1283,7 +1287,6 @@ def generate_data_saving_path(model_name, drug_id, result_path, suffix) -> str: return model_path -@pipeline_function def train_final_model( model_class: type[DRPModel], full_dataset: DrugResponseDataset, @@ -1291,7 +1294,7 @@ def train_final_model( path_data: str, model_checkpoint_dir: str, metric: str, - result_path: str, + final_model_path: str, test_mode: str = "LCO", val_ratio: float = 0.1, hyperparameter_tuning: bool = True, @@ -1314,7 +1317,7 @@ def train_final_model( :param path_data: path to data directory :param model_checkpoint_dir: checkpoint dir for intermediate tuning models :param metric: metric for tuning, e.g., "RMSE" - :param result_path: path to results + :param final_model_path: path to final_model save directory :param test_mode: split logic for validation (LCO, LDO, LTO, LPO) :param val_ratio: validation size ratio :param hyperparameter_tuning: whether to perform hyperparameter tuning @@ -1356,17 +1359,25 @@ def train_final_model( print(f"Best hyperparameters for final model: {best_hpams}") train_dataset.add_rows(validation_dataset) train_dataset.shuffle(random_state=42) + if response_transformation: + train_dataset.fit_transform(response_transformation) + if early_stopping_dataset is not None: + early_stopping_dataset.transform(response_transformation) model.build_model(hyperparameters=best_hpams) + drug_features = drug_features.copy() if drug_features is not None else None model.train( output=train_dataset, output_earlystopping=early_stopping_dataset, - cell_line_input=cl_features, + cell_line_input=cl_features.copy(), drug_input=drug_features, model_checkpoint_dir=model_checkpoint_dir, ) + if response_transformation: + train_dataset.inverse_transform(response_transformation) + if early_stopping_dataset is not None: + early_stopping_dataset.inverse_transform(response_transformation) - final_model_path = os.path.join(result_path, "final_model") os.makedirs(final_model_path, exist_ok=True) model.save(final_model_path) diff --git a/drevalpy/utils.py b/drevalpy/utils.py index c21242f0..4695b44f 100644 --- a/drevalpy/utils.py +++ b/drevalpy/utils.py @@ -392,7 +392,7 @@ def get_datasets( @pipeline_function -def get_response_transformation(response_transformation: str) -> TransformerMixin | None: +def get_response_transformation(response_transformation: str | None) -> TransformerMixin | None: """ Get the skelarn response transformation object of choice. @@ -401,7 +401,7 @@ def get_response_transformation(response_transformation: str) -> TransformerMixi :returns: response transformation object :raises ValueError: if the response transformation is not recognized """ - if response_transformation == "None": + if (response_transformation == "None") or (response_transformation is None): return None if response_transformation == "standard": return StandardScaler() diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py index 428c108f..58e78961 100644 --- a/drevalpy/visualization/utils.py +++ b/drevalpy/visualization/utils.py @@ -12,6 +12,7 @@ from ..datasets.dataset import DrugResponseDataset from ..evaluation import AVAILABLE_METRICS, evaluate +from ..models.utils import CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER from ..pipeline_function import pipeline_function from . import ( ComparisonScatter, @@ -228,7 +229,7 @@ def prep_results( elif file == "cell_line_names.csv": cell_line_names = pd.read_csv(os.path.join(root, file), index_col=0) # index: cellosaurus_id, column: cell_line_name - cell_line_metadata.update(zip(cell_line_names["cell_line_name"], cell_line_names.index)) + cell_line_metadata.update(zip(cell_line_names[CELL_LINE_IDENTIFIER], cell_line_names.index)) # add variables # split the index by "_" into: algorithm, randomization, test_mode, split, CV_split @@ -251,7 +252,7 @@ def prep_results( all_drugs = [drug_metadata[drug] for drug in eval_results_per_drug["drug"]] eval_results_per_drug["drug_name"] = all_drugs # rename drug to pubchem_id - eval_results_per_drug = eval_results_per_drug.rename(columns={"drug": "pubchem_id"}) + eval_results_per_drug = eval_results_per_drug.rename(columns={"drug": DRUG_IDENTIFIER}) if eval_results_per_cell_line is not None: print("Reformatting the evaluation results per cell line ...") eval_results_per_cell_line[["algorithm", "rand_setting", "test_mode", "split", "CV_split"]] = ( @@ -259,7 +260,7 @@ def prep_results( ) all_cello_ids = [cell_line_metadata[cell_line] for cell_line in eval_results_per_cell_line["cell_line"]] eval_results_per_cell_line["cellosaurus_id"] = all_cello_ids - eval_results_per_cell_line = eval_results_per_cell_line.rename(columns={"cell_line": "cell_line_name"}) + eval_results_per_cell_line = eval_results_per_cell_line.rename(columns={"cell_line": CELL_LINE_IDENTIFIER}) print("Reformatting the true vs. predicted values ...") t_vs_p[["algorithm", "rand_setting", "test_mode", "split", "CV_split"]] = t_vs_p["model"].str.split( @@ -270,8 +271,8 @@ def prep_results( t_vs_p["drug_name"] = all_drugs all_cello_ids = [cell_line_metadata[cell_line] for cell_line in t_vs_p["cell_line"]] t_vs_p["cellosaurus_id"] = all_cello_ids - t_vs_p = t_vs_p.rename(columns={"cell_line": "cell_line_name", "drug": "pubchem_id"}) - t_vs_p["pubchem_id"] = t_vs_p["pubchem_id"].astype(str) + t_vs_p = t_vs_p.rename(columns={"cell_line": CELL_LINE_IDENTIFIER, "drug": DRUG_IDENTIFIER}) + t_vs_p[DRUG_IDENTIFIER] = t_vs_p[DRUG_IDENTIFIER].astype(str) if "NaiveMeanEffectsPredictor" in eval_results["algorithm"].unique(): eval_results = _normalize_metrics_by_mean_effects( diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 5c0181d6..0162bfb9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -591,59 +591,76 @@ def test_add_features(sample_dataset: FeatureDataset, graph_dataset: FeatureData assert "molecular_graph" in sample_dataset.view_names -def test_feature_dataset_csv_methods(): - """Test the `from_csv` and `to_csv` methods of the FeatureDataset class.""" - # Create temporary directory for testing +def test_feature_dataset_csv_meta_handling(): + """Test `from_csv` and `to_csv` methods with and without meta_info handling.""" with tempfile.TemporaryDirectory() as temp_dir: temp_dir = Path(temp_dir) - # Create test data - test_csv_path = temp_dir / "test_features.csv" - data = { - "id": ["A", "B", "C"], - "feature_1": [1.0, 2.0, 3.0], - "feature_2": [4.0, 5.0, 6.0], - } - df = pd.DataFrame(data) - df.to_csv(test_csv_path, index=False) - - # Test `from_csv` method + # ------------------------------------ + # 0. Create initial test DataFrame/CSV + # ------------------------------------ + df_with_named_cols = pd.DataFrame( + { + "id": ["A", "B", "C"], + "feature_1": [1.0, 2.0, 3.0], + "feature_2": [4.0, 5.0, 6.0], + } + ) + csv_with_meta = temp_dir / "input_with_meta.csv" + df_with_named_cols.to_csv(csv_with_meta, index=False) + view_name = "example_view" - feature_dataset = FeatureDataset.from_csv( - path_to_csv=test_csv_path, id_column="id", view_name=view_name, drop_columns=None + + # ------------------------------------ + # 1. Load from CSV → should extract meta_info + # ------------------------------------ + dataset = FeatureDataset.from_csv( + path_to_csv=csv_with_meta, + id_column="id", + view_name=view_name, ) - # Validate loaded data - assert set(feature_dataset.identifiers) == {"A", "B", "C"}, "Identifiers mismatch." - assert feature_dataset.view_names == [view_name], "View names mismatch." - expected_features = { - "A": {"example_view": np.array([1.0, 4.0])}, - "B": {"example_view": np.array([2.0, 5.0])}, - "C": {"example_view": np.array([3.0, 6.0])}, - } - for identifier in expected_features: - np.testing.assert_array_equal( - feature_dataset.features[identifier][view_name], - expected_features[identifier][view_name], - f"Feature mismatch for identifier {identifier}.", - ) - - # Test `to_csv` method - output_csv_path = temp_dir / "output_features.csv" - feature_dataset.to_csv(path=output_csv_path, id_column="id", view_name=view_name) - - # Validate saved data - saved_df = pd.read_csv(output_csv_path) - expected_saved_df = pd.DataFrame( + assert dataset.meta_info == {view_name: ["feature_1", "feature_2"]} + assert set(dataset.identifiers) == {"A", "B", "C"} + assert dataset.view_names == [view_name] + + # ------------------------------------ + # 2. Save with meta_info → column names should be preserved + # ------------------------------------ + csv_out_with_meta = temp_dir / "saved_with_meta.csv" + dataset.to_csv(csv_out_with_meta, id_column="id", view_name=view_name) + + saved_df = pd.read_csv(csv_out_with_meta) + pd.testing.assert_frame_equal(saved_df, df_with_named_cols, check_dtype=False) + + # ------------------------------------ + # 3. Save without meta_info → fallback to generic feature_0, feature_1 + # ------------------------------------ + dataset._meta_info = {} # simulate no meta info + csv_out_no_meta = temp_dir / "saved_no_meta.csv" + dataset.to_csv(csv_out_no_meta, id_column="id", view_name=view_name) + + df_fallback = pd.DataFrame( { "id": ["A", "B", "C"], "feature_0": [1.0, 2.0, 3.0], "feature_1": [4.0, 5.0, 6.0], } ) - pd.testing.assert_frame_equal( - saved_df, - expected_saved_df, - check_dtype=False, # Relax dtype check for cross-platform compatibility - obj="Saved CSV data", + saved_fallback_df = pd.read_csv(csv_out_no_meta) + pd.testing.assert_frame_equal(saved_fallback_df, df_fallback, check_dtype=False) + + # ------------------------------------ + # 4. Load fallback CSV → should reconstruct generic meta_info + # ------------------------------------ + dataset_fallback = FeatureDataset.from_csv( + path_to_csv=csv_out_no_meta, + id_column="id", + view_name=view_name, + ) + + assert dataset_fallback.meta_info == {view_name: ["feature_0", "feature_1"]} + np.testing.assert_array_equal( + dataset_fallback.features["B"][view_name], + np.array([2.0, 5.0]), ) diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py index 8583d80f..e6f2ff1b 100644 --- a/tests/test_run_suite.py +++ b/tests/test_run_suite.py @@ -38,11 +38,12 @@ "overwrite": False, "optim_metric": "RMSE", "n_cv_splits": 2, - "response_transformation": "None", + "response_transformation": "standard", "multiprocessing": False, "path_data": "../data", "model_checkpoint_dir": "TEMPORARY", - "final_model_on_full_data": False, + "no_hyperparameter_tuning": True, + "final_model_on_full_data": True, } ], ) @@ -55,7 +56,6 @@ def test_run_suite(args): temp_dir = tempfile.TemporaryDirectory() args["path_out"] = temp_dir.name args = Namespace(**args) - args.no_hyperparameter_tuning = True get_parser() check_arguments(args) main(args)