99from ..pipeline_function import pipeline_function
1010from .curvecurator import fit_curves
1111from .dataset import DrugResponseDataset
12- from .utils import download_dataset
12+ from .utils import CELL_LINE_IDENTIFIER , DRUG_IDENTIFIER , download_dataset
1313
1414
1515def load_gdsc1 (
1616 path_data : str = "data" ,
17- measure : str = "LN_IC50 " ,
18- file_name : str = "response_GDSC1 .csv" ,
17+ measure : str = "LN_IC50_curvecurator " ,
18+ file_name : str = "GDSC1 .csv" ,
1919 dataset_name : str = "GDSC1" ,
2020) -> DrugResponseDataset :
2121 """
@@ -32,18 +32,18 @@ def load_gdsc1(
3232 if not os .path .exists (path ):
3333 download_dataset (dataset_name , path_data , redownload = True )
3434
35- response_data = pd .read_csv (path )
36- response_data ["DRUG_NAME" ] = response_data ["DRUG_NAME" ].str .replace ("," , "" )
35+ response_data = pd .read_csv (path , dtype = { "pubchem_id" : str } )
36+ response_data [DRUG_IDENTIFIER ] = response_data [DRUG_IDENTIFIER ].str .replace ("," , "" )
3737
3838 return DrugResponseDataset (
3939 response = response_data [measure ].values ,
40- cell_line_ids = response_data ["CELL_LINE_NAME" ].values ,
41- drug_ids = response_data ["DRUG_NAME" ].values ,
40+ cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
41+ drug_ids = response_data [DRUG_IDENTIFIER ].values ,
4242 dataset_name = dataset_name ,
4343 )
4444
4545
46- def load_gdsc2 (path_data : str = "data" , measure : str = "LN_IC50 " , file_name : str = "response_GDSC2 .csv" ):
46+ def load_gdsc2 (path_data : str = "data" , measure : str = "LN_IC50_curvecurator " , file_name : str = "GDSC2 .csv" ):
4747 """
4848 Loads the GDSC2 dataset.
4949
@@ -57,7 +57,7 @@ def load_gdsc2(path_data: str = "data", measure: str = "LN_IC50", file_name: str
5757
5858
5959def load_ccle (
60- path_data : str = "data" , measure : str = "LN_IC50 " , file_name : str = "response_CCLE .csv"
60+ path_data : str = "data" , measure : str = "LN_IC50_curvecurator " , file_name : str = "CCLE .csv"
6161) -> DrugResponseDataset :
6262 """
6363 Loads the CCLE dataset.
@@ -73,18 +73,18 @@ def load_ccle(
7373 if not os .path .exists (path ):
7474 download_dataset (dataset_name , path_data , redownload = True )
7575
76- response_data = pd .read_csv (path )
77- response_data ["DRUG_NAME" ] = response_data ["DRUG_NAME" ].str .replace ("," , "" )
76+ response_data = pd .read_csv (path , dtype = { "pubchem_id" : str } )
77+ response_data [DRUG_IDENTIFIER ] = response_data [DRUG_IDENTIFIER ].str .replace ("," , "" )
7878
7979 return DrugResponseDataset (
8080 response = response_data [measure ].values ,
81- cell_line_ids = response_data ["CELL_LINE_NAME" ].values ,
82- drug_ids = response_data ["DRUG_NAME" ].values ,
81+ cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
82+ drug_ids = response_data [DRUG_IDENTIFIER ].values ,
8383 dataset_name = dataset_name ,
8484 )
8585
8686
87- def load_toy (path_data : str = "data" , measure : str = "response " ) -> DrugResponseDataset :
87+ def load_toy (path_data : str = "data" , measure : str = "LN_IC50_curvecurator " ) -> DrugResponseDataset :
8888 """
8989 Loads small Toy dataset, subsampled from GDSC1.
9090
@@ -94,20 +94,67 @@ def load_toy(path_data: str = "data", measure: str = "response") -> DrugResponse
9494 :return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
9595 """
9696 dataset_name = "Toy_Data"
97- measure = "response" # overwrite this explicitly to avoid problems, should be changed in the future
9897 path = os .path .join (path_data , dataset_name , "toy_data.csv" )
9998 if not os .path .exists (path ):
10099 download_dataset (dataset_name , path_data , redownload = True )
101- response_data = pd .read_csv (path )
100+ response_data = pd .read_csv (path , dtype = { "pubchem_id" : str } )
102101
103102 return DrugResponseDataset (
104103 response = response_data [measure ].values ,
105- cell_line_ids = response_data ["cell_line_id" ].values ,
106- drug_ids = response_data ["drug_id" ].values ,
104+ cell_line_ids = response_data [CELL_LINE_IDENTIFIER ].values ,
105+ drug_ids = response_data [DRUG_IDENTIFIER ].values ,
107106 dataset_name = dataset_name ,
108107 )
109108
110109
def _load_ctrpv(version: str, path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
    """
    Load one version of the CTRP dataset.

    The dataset name is built as "CTRPv" + version, and the response file is looked up at
    <path_data>/<dataset_name>/<dataset_name>.csv. If that file does not exist, download_dataset
    is called for the dataset before the file is read.

    :param version: The version of the CTRP dataset to load, appended to "CTRPv" (e.g. "1" or "2")
    :param path_data: Path to the directory that contains (or will receive) the dataset folder
    :param measure: The name of the column containing the measure to predict,
        default = "LN_IC50_curvecurator"

    :return: DrugResponseDataset containing response, cell line IDs, and drug IDs
    """
    dataset_name = "CTRPv" + version
    # The file name must be exactly "<dataset_name>.csv" (no embedded whitespace),
    # otherwise the existence check below can never succeed.
    path = os.path.join(path_data, dataset_name, f"{dataset_name}.csv")
    if not os.path.exists(path):
        download_dataset(dataset_name, path_data, redownload=True)
    # Read pubchem_id as strings instead of letting pandas infer a numeric dtype.
    response_data = pd.read_csv(path, dtype={"pubchem_id": str})

    return DrugResponseDataset(
        response=response_data[measure].values,
        cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
        drug_ids=response_data[DRUG_IDENTIFIER].values,
        dataset_name=dataset_name,
    )
132+
133+
def load_ctrpv1(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
    """
    Load CTRPv1 dataset.

    Thin wrapper that delegates to _load_ctrpv with version "1".

    :param path_data: Path to location of CTRPv1 dataset
    :param measure: The name of the column containing the measure to predict,
        default = "LN_IC50_curvecurator"

    :return: DrugResponseDataset containing response, cell line IDs, and drug IDs
    """
    return _load_ctrpv("1", path_data, measure)
144+
145+
def load_ctrpv2(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
    """
    Load CTRPv2 dataset.

    Thin wrapper that delegates to _load_ctrpv with version "2".

    :param path_data: Path to location of CTRPv2 dataset
    :param measure: The name of the column containing the measure to predict,
        default = "LN_IC50_curvecurator"

    :return: DrugResponseDataset containing response, cell line IDs, and drug IDs
    """
    ctrp_version = "2"
    return _load_ctrpv(version=ctrp_version, path_data=path_data, measure=measure)
156+
157+
111158def load_custom (path_data : str | Path , measure : str = "response" ) -> DrugResponseDataset :
112159 """
113160 Load custom dataset.
@@ -125,6 +172,8 @@ def load_custom(path_data: str | Path, measure: str = "response") -> DrugRespons
125172 "GDSC2" : load_gdsc2 ,
126173 "CCLE" : load_ccle ,
127174 "Toy_Data" : load_toy ,
175+ "CTRPv1" : load_ctrpv1 ,
176+ "CTRPv2" : load_ctrpv2 ,
128177}
129178
130179
0 commit comments