Skip to content

Commit 499ccba

Browse files
committed
mypy
1 parent 975afa6 commit 499ccba

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

drevalpy/visualization/create_report.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Iterable
77
from typing import Union
88

9+
import numpy as np
910
import pandas as pd
1011

1112
from drevalpy.visualization.utils import (
@@ -42,8 +43,11 @@ def generate_reports_for_test_mode(
4243
:param path_data: Path to the dataset directory.
4344
:param result_path: Path to the results directory.
4445
"""
46+
path_data = pathlib.Path(path_data)
47+
result_path = pathlib.Path(result_path)
48+
4549
print(f"Generating report for {test_mode} ...")
46-
unique_algos: list[str] = draw_test_mode_plots(
50+
unique_algos_ndarray = draw_test_mode_plots(
4751
test_mode=test_mode,
4852
ev_res=evaluation_results,
4953
ev_res_per_drug=evaluation_results_per_drug,
@@ -52,6 +56,10 @@ def generate_reports_for_test_mode(
5256
path_data=path_data,
5357
result_path=result_path,
5458
)
59+
unique_algos: Iterable[str] = (
60+
list(unique_algos_ndarray) if isinstance(unique_algos_ndarray, (np.ndarray, tuple)) else unique_algos_ndarray
61+
)
62+
5563
unique_algos_set = set(unique_algos) - {
5664
"NaiveMeanEffectsPredictor",
5765
"NaivePredictor",
@@ -71,7 +79,7 @@ def generate_reports_for_test_mode(
7179
result_path=result_path,
7280
)
7381

74-
all_files: list[str] = []
82+
all_files = []
7583
for _, _, files in os.walk(f"{result_path}/{run_id}"):
7684
for file in files:
7785
if file.endswith("json") or (
@@ -122,7 +130,7 @@ def generate_reports_for_all_test_modes(
122130
)
123131

124132

125-
def render_report(
133+
def create_report(
126134
run_id: str,
127135
dataset: str,
128136
path_data: Union[str, pathlib.Path] = "data",
@@ -201,7 +209,7 @@ def main() -> None:
201209
parser.add_argument("--path_data", default="data", help="Path to the data")
202210
parser.add_argument("--result_path", default="results", help="Path to the results")
203211
args = parser.parse_args()
204-
render_report(args.run_id, args.dataset, args.path_data, args.result_path)
212+
create_report(args.run_id, args.dataset, args.path_data, args.result_path)
205213

206214

207215
if __name__ == "__main__":

0 commit comments

Comments
 (0)