Skip to content

Commit 342860b

Browse files
Merge pull request #131 from daisybio/multi_features
Multi features
2 parents b36765d + 078a5e9 commit 342860b

File tree

5 files changed

+16
-20
lines changed

5 files changed

+16
-20
lines changed

create_report.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,10 @@ def draw_per_grouping_algorithm_plots(
267267
if __name__ == "__main__":
268268
parser = argparse.ArgumentParser(description="Generate reports from evaluation results")
269269
parser.add_argument("--run_id", required=True, help="Run ID for the current execution")
270+
parser.add_argument("--dataset", required=True, help="Dataset name for which to render the result file")
270271
args = parser.parse_args()
271272
run_id = args.run_id
273+
dataset = args.dataset
272274

273275
# assert that the run_id folder exists
274276
if not os.path.exists(f"results/{run_id}"):
@@ -280,7 +282,7 @@ def draw_per_grouping_algorithm_plots(
280282
evaluation_results_per_drug,
281283
evaluation_results_per_cell_line,
282284
true_vs_pred,
283-
) = parse_results(path_to_results=f"results/{run_id}")
285+
) = parse_results(path_to_results=f"results/{run_id}", dataset=dataset)
284286

285287
# part of pipeline: EVALUATE_FINAL, COLLECT_RESULTS
286288
(

drevalpy/models/utils.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utility functions for loading and processing data."""
22

33
import os.path
4-
import warnings
54

65
import numpy as np
76
import pandas as pd
@@ -94,17 +93,11 @@ def iterate_features(df: pd.DataFrame, feature_type: str) -> dict[str, dict[str,
9493
if cl in features.keys():
9594
continue
9695
rows = df.loc[cl]
96+
rows = rows.astype(float).to_numpy()
9797
if (len(rows.shape) > 1) and (rows.shape[0] > 1): # multiple rows returned
98-
warnings.warn(
99-
f"Multiple rows returned for Cell Line {cl} (and maybe others) "
100-
f"in feature {feature_type}, taking the first one.",
101-
stacklevel=2,
102-
)
103-
104-
rows = rows.iloc[0]
105-
# convert to float values
106-
rows = rows.astype(float)
107-
features[cl] = {feature_type: rows.values}
98+
# several rows share this cell line ID: average them feature-wise into a single vector
99+
rows = np.mean(rows, axis=0)
100+
features[cl] = {feature_type: rows}
108101
return features
109102

110103

drevalpy/visualization/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ def _parse_layout(f: TextIO, path_to_layout: str) -> None:
3737
f.write("".join(layout))
3838

3939

40-
def parse_results(path_to_results: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
40+
def parse_results(path_to_results: str, dataset: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
4141
"""
4242
Parse the results from the given directory.
4343
4444
:param path_to_results: path to the results directory
45+
:param dataset: dataset name, e.g., GDSC2; only result files under <path_to_results>/<dataset>/ are parsed
4546
:returns: evaluation results, evaluation results per drug, evaluation results per cell line, and true vs. predicted
4647
values
4748
"""
@@ -54,7 +55,7 @@ def parse_results(path_to_results: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.
5455
# Convert the path to a forward-slash version for the regex (for Windows)
5556
result_dir_str = str(result_dir).replace("\\", "/")
5657
pattern = re.compile(
57-
rf"{result_dir_str}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$"
58+
rf"{result_dir_str}/{dataset}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$"
5859
)
5960
result_files = [file for file in result_files if pattern.match(str(file).replace("\\", "/"))]
6061

@@ -69,8 +70,9 @@ def parse_results(path_to_results: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.
6970
rel_file = str(os.path.normpath(file.relative_to(result_dir))).replace("\\", "/")
7071
print(f'Evaluating file: "{rel_file}" ...')
7172
file_parts = rel_file.split("/")
72-
lpo_lco_ldo = file_parts[0]
73-
algorithm = file_parts[1]
73+
dataset = file_parts[0]
74+
lpo_lco_ldo = file_parts[1]
75+
algorithm = file_parts[2]
7476
(
7577
overall_eval,
7678
eval_results_per_drug,

tests/test_drp_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,9 @@ def test_iterate_features() -> None:
151151
"""Test the iteration over features."""
152152
df = pd.DataFrame({"GeneA": [1, 2, 3, 2], "GeneB": [4, 5, 6, 2], "GeneC": [7, 8, 9, 2]})
153153
df.index = ["CellLine1", "CellLine2", "CellLine3", "CellLine1"]
154-
with pytest.warns(UserWarning):
155-
features = iterate_features(df, "gene_expression")
154+
features = iterate_features(df, "gene_expression")
156155
assert len(features) == 3
157-
assert np.all(features["CellLine1"]["gene_expression"] == [1, 4, 7])
156+
assert np.all(features["CellLine1"]["gene_expression"] == [1.5, 3, 4.5])
158157

159158

160159
def test_load_drug_ids_from_csv() -> None:

tests/test_run_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_run_suite(args):
5353
evaluation_results_per_drug,
5454
evaluation_results_per_cell_line,
5555
true_vs_pred,
56-
) = parse_results(path_to_results=os.path.join(temp_dir.name, args.run_id, args.dataset_name))
56+
) = parse_results(path_to_results=os.path.join(temp_dir.name, args.run_id), dataset="Toy_Data")
5757

5858
(
5959
evaluation_results,

0 commit comments

Comments
 (0)