66from collections .abc import Iterable
77from typing import Union
88
9+ import numpy as np
910import pandas as pd
1011
1112from drevalpy .visualization .utils import (
@@ -42,8 +43,11 @@ def generate_reports_for_test_mode(
4243 :param path_data: Path to the dataset directory.
4344 :param result_path: Path to the results directory.
4445 """
46+ path_data = pathlib .Path (path_data )
47+ result_path = pathlib .Path (result_path )
48+
4549 print (f"Generating report for { test_mode } ..." )
46- unique_algos : list [ str ] = draw_test_mode_plots (
50+ unique_algos_ndarray = draw_test_mode_plots (
4751 test_mode = test_mode ,
4852 ev_res = evaluation_results ,
4953 ev_res_per_drug = evaluation_results_per_drug ,
@@ -52,6 +56,10 @@ def generate_reports_for_test_mode(
5256 path_data = path_data ,
5357 result_path = result_path ,
5458 )
59+ unique_algos : Iterable [str ] = (
60+ list (unique_algos_ndarray ) if isinstance (unique_algos_ndarray , (np .ndarray , tuple )) else unique_algos_ndarray
61+ )
62+
5563 unique_algos_set = set (unique_algos ) - {
5664 "NaiveMeanEffectsPredictor" ,
5765 "NaivePredictor" ,
@@ -71,7 +79,7 @@ def generate_reports_for_test_mode(
7179 result_path = result_path ,
7280 )
7381
74- all_files : list [ str ] = []
82+ all_files = []
7583 for _ , _ , files in os .walk (f"{ result_path } /{ run_id } " ):
7684 for file in files :
7785 if file .endswith ("json" ) or (
@@ -122,7 +130,7 @@ def generate_reports_for_all_test_modes(
122130 )
123131
124132
125- def render_report (
133+ def create_report (
126134 run_id : str ,
127135 dataset : str ,
128136 path_data : Union [str , pathlib .Path ] = "data" ,
@@ -201,7 +209,7 @@ def main() -> None:
201209 parser .add_argument ("--path_data" , default = "data" , help = "Path to the data" )
202210 parser .add_argument ("--result_path" , default = "results" , help = "Path to the results" )
203211 args = parser .parse_args ()
204- render_report (args .run_id , args .dataset , args .path_data , args .result_path )
212+ create_report (args .run_id , args .dataset , args .path_data , args .result_path )
205213
206214
207215if __name__ == "__main__" :
0 commit comments