
Commit 1c8e101

Merge pull request #88 from daisybio/feature_loading
Feature loading from_csv to_csv for FeatureDataset
2 parents ad24b3b + d1fe923 commit 1c8e101

File tree

9 files changed: +180 -128 lines changed

drevalpy/datasets/dataset.py

Lines changed: 57 additions & 45 deletions
@@ -15,7 +15,6 @@
 
 import copy
 import os
-from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable
 
@@ -31,30 +30,7 @@
 np.set_printoptions(threshold=6)
 
 
-class Dataset(ABC):
-    """Abstract wrapper class for datasets."""
-
-    @classmethod
-    @abstractmethod
-    def from_csv(cls: type["Dataset"], input_file: str | Path, dataset_name: str = "unknown") -> "Dataset":
-        """
-        Loads the dataset from data.
-
-        :param input_file: Path to the csv file containing the data to be loaded
-        :param dataset_name: Optional name to associate the dataset with, default = "unknown"
-        :returns: Dataset object containing data from provided csv file.
-        """
-
-    @abstractmethod
-    def save(self, path: str):
-        """
-        Saves the dataset to data.
-
-        :param path: path to the dataset
-        """
-
-
-class DrugResponseDataset(Dataset):
+class DrugResponseDataset:
     """Drug response dataset."""
 
     _response: np.ndarray
@@ -226,7 +202,7 @@ def to_dataframe(self) -> pd.DataFrame:
             data["predictions"] = self.predictions
         return pd.DataFrame(data)
 
-    def save(self, path: str | Path):
+    def to_csv(self, path: str | Path):
         """
         Stores the drug response dataset on disk.
 
@@ -412,7 +388,7 @@ def save_splits(self, path: str):
             ]:
                 if mode in split:
                     split_path = os.path.join(path, f"cv_split_{i}_{mode}.csv")
-                    split[mode].save(path=split_path)
+                    split[mode].to_csv(path=split_path)
 
     def load_splits(self, path: str) -> None:
         """
@@ -720,25 +696,70 @@ def _leave_group_out_cv(
     return cv_sets
 
 
-class FeatureDataset(Dataset):
+class FeatureDataset:
     """Class for feature datasets."""
 
     _meta_info: dict[str, Any] = {}
     _features: dict[str, dict[str, Any]] = {}
 
     @classmethod
     def from_csv(
-        cls: type["FeatureDataset"], input_file: str | Path, dataset_name: str = "unknown"
-    ) -> "FeatureDataset":
+        cls: type["FeatureDataset"],
+        path_to_csv: str | Path,
+        id_column: str,
+        view_name: str,
+        drop_columns: list[str] | None = None,
+    ):
+        """Load a one-view feature dataset from a csv file.
+
+        Load a feature dataset from a csv file. The rows of the csv file represent the instances (cell lines or drugs),
+        the columns represent the features. A column named id_column contains the identifiers of the instances.
+        All unrelated columns (e.g. other id columns) should be provided as drop_columns,
+        that will be removed from the dataset.
+
+        :param path_to_csv: path to the csv file containing the data to be loaded
+        :param view_name: name of the view (e.g. gene_expression)
+        :param id_column: name of the column containing the identifiers
+        :param drop_columns: list of columns to drop (e.g. other identifier columns)
+        :returns: FeatureDataset object containing data from provided csv file.
+        """
+        data = pd.read_csv(path_to_csv)
+        ids = data[id_column].values
+        data_features = data.drop(columns=(drop_columns or []))
+        data_features = data_features.set_index(id_column)
+        # remove duplicate feature rows (rows with the same index)
+        data_features = data_features[~data_features.index.duplicated(keep="first")]
+        features = {}
+
+        for identifier in ids:
+            features_for_instance = data_features.loc[identifier].values
+            features[identifier] = {view_name: features_for_instance}
+
+        return cls(features=features)
+
+    def to_csv(self, path: str | Path, id_column: str, view_name: str):
         """
-        Load a feature dataset from a csv file.
+        Save the feature dataset to a CSV file.
 
-        This function creates a FeatureDataset from a provided input file in csv format.
-        :param input_file: Path to the csv file containing the data to be loaded
-        :param dataset_name: Optional name to associate the dataset with, default = "unknown"
-        :raises NotImplementedError: This method is currently not implemented.
+        :param path: Path to the CSV file.
+        :param id_column: Name of the column containing the identifiers.
+        :param view_name: Name of the view (e.g., gene_expression).
+
+        :raises ValueError: If the view is not found for an identifier.
         """
-        raise NotImplementedError
+        data = []
+        for identifier, feature_dict in self.features.items():
+            # Get the feature vector for the specified view
+            if view_name in feature_dict:
+                row = {id_column: identifier}
+                row.update({f"feature_{i}": value for i, value in enumerate(feature_dict[view_name])})
+                data.append(row)
+            else:
+                raise ValueError(f"View {view_name!r} not found for identifier {identifier!r}.")
+
+        # Convert to DataFrame and save to CSV
+        df = pd.DataFrame(data)
+        df.to_csv(path, index=False)
 
     @property
     def meta_info(self) -> dict[str, Any]:
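
The new `FeatureDataset.from_csv`/`to_csv` pair replaces the old `NotImplementedError` stubs with a concrete one-view CSV round trip. A minimal usage sketch follows, grounded in the signatures and CSV layout described in the hunk above; the file name, identifier column, and feature values are illustrative only.

```python
# Sketch of the one-view CSV round trip added in this commit. File and column
# names are illustrative; the method signatures match the diff above.
import pandas as pd
from drevalpy.datasets.dataset import FeatureDataset

# One row per cell line: an identifier column, feature columns, and an
# unrelated column that from_csv is asked to drop.
pd.DataFrame(
    {
        "cell_line_id": ["CL1", "CL2", "CL3"],
        "gene_1": [0.1, 0.4, 0.7],
        "gene_2": [1.2, 0.9, 1.5],
        "other_id": ["a", "b", "c"],
    }
).to_csv("gene_expression.csv", index=False)

fd = FeatureDataset.from_csv(
    path_to_csv="gene_expression.csv",
    id_column="cell_line_id",
    view_name="gene_expression",
    drop_columns=["other_id"],
)

# Write the single view back out; the output contains the id column plus
# feature_0, feature_1, ... columns.
fd.to_csv(path="gene_expression_out.csv", id_column="cell_line_id", view_name="gene_expression")
```

Note that `to_csv` renames the feature columns to `feature_0`, `feature_1`, ..., so the round trip preserves values but not the original column names.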
@@ -798,15 +819,6 @@ def __init__(
             raise AssertionError(f"Meta keys {meta_info.keys()} not in view names {self.view_names}")
         self._meta_info = meta_info
 
-    def save(self, path: str):
-        """
-        Saves the feature dataset to data.
-
-        :param path: path to the dataset
-        :raises NotImplementedError: if method is not implemented
-        """
-        raise NotImplementedError("save method not implemented")
-
     def randomize_features(self, views_to_randomize: str | list[str], randomization_type: str) -> None:
         """
         Randomizes the feature vectors.
