102 changes: 57 additions & 45 deletions drevalpy/datasets/dataset.py
@@ -15,7 +15,6 @@

import copy
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable

@@ -31,30 +30,7 @@
np.set_printoptions(threshold=6)


class Dataset(ABC):
"""Abstract wrapper class for datasets."""

@classmethod
@abstractmethod
def from_csv(cls: type["Dataset"], input_file: str | Path, dataset_name: str = "unknown") -> "Dataset":
"""
Loads the dataset from data.

:param input_file: Path to the csv file containing the data to be loaded
:param dataset_name: Optional name to associate the dataset with, default = "unknown"
:returns: Dataset object containing data from provided csv file.
"""

@abstractmethod
def save(self, path: str):
"""
Saves the dataset to data.

:param path: path to the dataset
"""


class DrugResponseDataset(Dataset):
class DrugResponseDataset:
"""Drug response dataset."""

_response: np.ndarray
@@ -226,7 +202,7 @@ def to_dataframe(self) -> pd.DataFrame:
data["predictions"] = self.predictions
return pd.DataFrame(data)

def save(self, path: str | Path):
def to_csv(self, path: str | Path):
"""
Stores the drug response dataset on disk.

Expand Down Expand Up @@ -412,7 +388,7 @@ def save_splits(self, path: str):
]:
if mode in split:
split_path = os.path.join(path, f"cv_split_{i}_{mode}.csv")
split[mode].save(path=split_path)
split[mode].to_csv(path=split_path)
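
A brief usage sketch of how this rename surfaces to callers, assuming drd is an already constructed DrugResponseDataset on which cross-validation splits have been generated; the variable name and paths are illustrative only, not part of this PR:

# Hypothetical example: the former save(...) is now to_csv(...)
drd.to_csv(path="out/gdsc1_responses.csv")
# save_splits writes each split through the same renamed method,
# yielding files such as out/splits/cv_split_0_train.csv
drd.save_splits(path="out/splits")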

def load_splits(self, path: str) -> None:
"""
Expand Down Expand Up @@ -720,25 +696,70 @@ def _leave_group_out_cv(
return cv_sets


class FeatureDataset(Dataset):
class FeatureDataset:
"""Class for feature datasets."""

_meta_info: dict[str, Any] = {}
_features: dict[str, dict[str, Any]] = {}

@classmethod
def from_csv(
cls: type["FeatureDataset"], input_file: str | Path, dataset_name: str = "unknown"
) -> "FeatureDataset":
cls: type["FeatureDataset"],
path_to_csv: str | Path,
id_column: str,
view_name: str,
drop_columns: list[str] | None = None,
):
"""Load a one-view feature dataset from a csv file.

Load a feature dataset from a csv file. The rows of the csv file represent the instances (cell lines or drugs)
and the columns represent the features. The column named id_column contains the identifiers of the instances.
Any unrelated columns (e.g. other identifier columns) should be passed via drop_columns; they will be removed
from the dataset.

:param path_to_csv: path to the csv file containing the data to be loaded
:param view_name: name of the view (e.g. gene_expression)
:param id_column: name of the column containing the identifiers
:param drop_columns: list of columns to drop (e.g. other identifier columns)
:returns: FeatureDataset object containing data from provided csv file.
"""
data = pd.read_csv(path_to_csv)
ids = data[id_column].values
data_features = data.drop(columns=(drop_columns or []))
data_features = data_features.set_index(id_column)
# remove duplicate feature rows (rows with the same index)
data_features = data_features[~data_features.index.duplicated(keep="first")]
features = {}

for identifier in ids:
features_for_instance = data_features.loc[identifier].values
features[identifier] = {view_name: features_for_instance}

return cls(features=features)
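
A usage sketch of the new loader; the file name, column names and view name below are assumptions for illustration, not part of the PR:

# Hypothetical example: load a single gene-expression view keyed by a cell line identifier column
fd = FeatureDataset.from_csv(
    path_to_csv="gene_expression.csv",
    id_column="cell_line_id",
    view_name="gene_expression",
    drop_columns=["cell_line_name"],  # other identifier columns to discard
)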

def to_csv(self, path: str | Path, id_column: str, view_name: str):
"""
Load a feature dataset from a csv file.
Save the feature dataset to a CSV file.

This function creates a FeatureDataset from a provided input file in csv format.
:param input_file: Path to the csv file containing the data to be loaded
:param dataset_name: Optional name to associate the dataset with, default = "unknown"
:raises NotImplementedError: This method is currently not implemented.
:param path: Path to the CSV file.
:param id_column: Name of the column containing the identifiers.
:param view_name: Name of the view (e.g., gene_expression).

:raises ValueError: If the view is not found for an identifier.
"""
raise NotImplementedError
data = []
for identifier, feature_dict in self.features.items():
# Get the feature vector for the specified view
if view_name in feature_dict:
row = {id_column: identifier}
row.update({f"feature_{i}": value for i, value in enumerate(feature_dict[view_name])})
data.append(row)
else:
raise ValueError(f"View {view_name!r} not found for identifier {identifier!r}.")

# Convert to DataFrame and save to CSV
df = pd.DataFrame(data)
df.to_csv(path, index=False)
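
A round-trip sketch pairing the new writer with the loader above; paths and names are again illustrative:

# Hypothetical example: export the gene_expression view and read it back in
fd.to_csv(path="gene_expression_export.csv", id_column="cell_line_id", view_name="gene_expression")
fd_reloaded = FeatureDataset.from_csv(
    path_to_csv="gene_expression_export.csv",
    id_column="cell_line_id",
    view_name="gene_expression",
)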

@property
def meta_info(self) -> dict[str, Any]:
@@ -798,15 +819,6 @@ def __init__(
raise AssertionError(f"Meta keys {meta_info.keys()} not in view names {self.view_names}")
self._meta_info = meta_info

def save(self, path: str):
"""
Saves the feature dataset to data.

:param path: path to the dataset
:raises NotImplementedError: if method is not implemented
"""
raise NotImplementedError("save method not implemented")

def randomize_features(self, views_to_randomize: str | list[str], randomization_type: str) -> None:
"""
Randomizes the feature vectors.
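
Finally, a sketch of calling the randomization API whose signature appears above; the view name and the "permutation" value are assumptions, since the accepted randomization types are not visible in this diff:

# Hypothetical example: shuffle one view's feature vectors across instances
# ("permutation" is an assumed randomization_type, not confirmed by this diff)
fd.randomize_features(views_to_randomize="gene_expression", randomization_type="permutation")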