99from ..pipeline_function import pipeline_function
1010from .curvecurator import fit_curves
1111from .dataset import DrugResponseDataset
12- from .utils import ALLOWED_MEASURES , CELL_LINE_IDENTIFIER , DRUG_IDENTIFIER , download_dataset
12+ from .utils import ALLOWED_MEASURES , CELL_LINE_IDENTIFIER , DRUG_IDENTIFIER , TISSUE_IDENTIFIER , download_dataset
1313
1414
1515def check_measure (measure_queried : str , measures_data : list [str ], dataset_name : str ) -> None :
@@ -56,6 +56,7 @@ def load_gdsc1(
5656 response = response_data [measure ].values ,
5757 cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
5858 drug_ids = response_data [DRUG_IDENTIFIER ].values ,
59+ tissues = response_data [TISSUE_IDENTIFIER ].values ,
5960 dataset_name = dataset_name ,
6061 )
6162
@@ -97,6 +98,7 @@ def load_ccle(
9798 response = response_data [measure ].values ,
9899 cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
99100 drug_ids = response_data [DRUG_IDENTIFIER ].values ,
101+ tissues = response_data [TISSUE_IDENTIFIER ].values ,
100102 dataset_name = dataset_name ,
101103 )
102104
@@ -122,6 +124,7 @@ def _load_toy(
122124 response = response_data [measure ].values ,
123125 cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
124126 drug_ids = response_data [DRUG_IDENTIFIER ].values ,
127+ tissues = response_data [TISSUE_IDENTIFIER ].values ,
125128 dataset_name = dataset_name ,
126129 )
127130
@@ -171,6 +174,7 @@ def _load_ctrpv(version: str, path_data: str = "data", measure: str = "LN_IC50_c
171174 response = response_data [measure ].values ,
172175 cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
173176 drug_ids = response_data [DRUG_IDENTIFIER ].values ,
177+ tissues = response_data [TISSUE_IDENTIFIER ].values ,
174178 dataset_name = dataset_name ,
175179 )
176180
@@ -199,16 +203,19 @@ def load_ctrpv2(path_data: str = "data", measure: str = "LN_IC50_curvecurator")
199203 return _load_ctrpv ("2" , path_data , measure )
200204
201205
202- def load_custom (path_data : str | Path , measure : str = "response" ) -> DrugResponseDataset :
206+ def load_custom (
207+ path_data : str | Path , measure : str = "response" , tissue_column : str | None = None
208+ ) -> DrugResponseDataset :
203209 """
204210 Load custom dataset.
205211
206212 :param path_data: Path to location of custom dataset
207213 :param measure: The name of the column containing the measure to predict, default = "response"
214+ :param tissue_column: The name of the column containing the tissue type. If None, no tissue information is loaded.
208215
209216 :return: DrugResponseDataset containing response, cell line IDs, drug IDs, and, if tissue_column is given, tissue information
210217 """
211- return DrugResponseDataset .from_csv (path_data , measure = measure )
218+ return DrugResponseDataset .from_csv (path_data , measure = measure , tissue_column = tissue_column )
212219
213220
214221AVAILABLE_DATASETS : dict [str , Callable ] = {
@@ -224,7 +231,12 @@ def load_custom(path_data: str | Path, measure: str = "response") -> DrugRespons
224231
225232@pipeline_function
226233def load_dataset (
227- dataset_name : str , path_data : str = "data" , measure : str = "response" , curve_curator : bool = False , cores : int = 1
234+ dataset_name : str ,
235+ path_data : str = "data" ,
236+ measure : str = "response" ,
237+ curve_curator : bool = False ,
238+ cores : int = 1 ,
239+ tissue_column : str | None = None ,
228240) -> DrugResponseDataset :
229241 """
230242 Load a dataset based on the dataset name.
@@ -243,6 +255,8 @@ def load_dataset(
243255 which is expected to exist at <path_data>/<dataset_name>/<dataset_name>_raw.csv. The fitted dataset will
244256 be stored in the same folder, in a file called <dataset_name>.csv
245257 :param cores: Number of cores to use for CurveCurator fitting. Only used when curve_curator is True, default = 1
258+ :param tissue_column: The name of the column containing the tissue type. If None, no tissue information is loaded.
259+ This is only used when loading a custom dataset. Default = None.
246260 :return: A DrugResponseDataset containing response, cell line IDs, drug IDs, tissue information (when available), and dataset name.
247261 :raises FileNotFoundError: If the custom dataset or raw viability data could not be found at the given path.
248262 """
@@ -263,5 +277,7 @@ def load_dataset(
263277 dataset_name = dataset_name ,
264278 cores = cores ,
265279 )
266- return load_custom (Path (path_data ) / dataset_name / f"{ dataset_name } .csv" , measure = measure )
280+ return load_custom (
281+ Path (path_data ) / dataset_name / f"{ dataset_name } .csv" , measure = measure , tissue_column = tissue_column
282+ )
267283 raise FileNotFoundError (f"Custom dataset does not exist at given path: { input_file } " )
0 commit comments