-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathaggregate_dags_script.py
More file actions
88 lines (68 loc) · 2.55 KB
/
aggregate_dags_script.py
File metadata and controls
88 lines (68 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Aggregating partial dags

This script aggregates the clinical dags by majority voting and creates a new dag for each
majority threshold from 1 up to, but not including, the number of partial dags found. The
resulting dags are saved as csv files, in the original format, within a new subdirectory of
the provided save directory.
"""
import argparse
from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np
import pandas as pd

from src.dataset import get_adjacency_matrix
from src.utils import aggregate_partial_dags, get_timestamp
def aggregate_and_save(
    adj_matrices: List[np.ndarray],
    majority_threshold: int,
    node_names: List[str],
    dirpath_save: Path,
) -> None:
    """Combine the partial dags by majority vote and write the result as a csv file.

    The output is written to ``aggregated_dag_threshold_<majority_threshold>.csv``
    inside ``dirpath_save``.

    Args:
        adj_matrices: the partial dag adjacency matrices to combine
        majority_threshold: the threshold forwarded to ``aggregate_partial_dags``
        node_names: node names in matrix order, used as both row and column labels
        dirpath_save: the directory in which the csv file is written
    """
    combined = aggregate_partial_dags(
        partial_dags=adj_matrices,
        majority_threshold=majority_threshold,
    )
    out_name = f"aggregated_dag_threshold_{majority_threshold}.csv"
    # Rows and columns share the same labels, so the saved csv is a square adjacency table.
    frame = pd.DataFrame(combined, index=node_names, columns=node_names)
    frame.to_csv(dirpath_save / out_name)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments of the script.

    Args:
        argv: the argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so calling ``parse_args()`` behaves as before.

    Returns:
        the parsed namespace, with ``dirpath_partial_dags`` and ``dirpath_save``
        converted to ``Path`` objects

    Raises:
        SystemExit: raised by argparse when a required argument is missing or invalid
    """
    parser = argparse.ArgumentParser(description="Aggregate the partial dags")
    parser.add_argument(
        "--dirpath_partial_dags",
        type=Path,
        required=True,
        help="the directory path where all the partial dags to aggregate are stored as csv files",
    )
    parser.add_argument(
        "--dirpath_save",
        type=Path,
        required=True,
        help="the dirpath where the set of aggregated dags will be saved",
    )
    return parser.parse_args(argv)
def main() -> None:
    """The entry point to the script.

    Reads every csv file in ``--dirpath_partial_dags`` as a partial dag. Then, for each
    majority threshold from 1 up to (but not including) the number of dags, writes one
    aggregated dag into a new subdirectory of ``--dirpath_save``.

    Raises:
        ValueError: if the partial-dag directory contains no csv files
    """
    args = parse_args()

    # iterdir() order is filesystem-dependent, so sort for a reproducible input order.
    # Skip subdirectories and stray non-csv files (e.g. .DS_Store) that read_csv would choke on.
    fpaths_dag = sorted(
        fpath
        for fpath in args.dirpath_partial_dags.iterdir()
        if fpath.is_file() and fpath.suffix.lower() == ".csv"
    )
    if not fpaths_dag:
        raise ValueError(f"No csv files found in {args.dirpath_partial_dags}")

    df_dags = [pd.read_csv(fpath_dag) for fpath_dag in fpaths_dag]
    adj_matrices = [get_adjacency_matrix(df_dag) for df_dag in df_dags]
    # Take node names from the first dag. The old index 1 crashed when only one dag
    # was present. The first csv column is skipped; presumably it is the row-label
    # column, matching how aggregate_and_save writes its output.
    node_names = df_dags[0].columns[1:]

    dirpath_save = args.dirpath_save.joinpath(get_timestamp() + args.dirpath_partial_dags.name)
    # parents=True also creates --dirpath_save itself if it does not exist yet.
    dirpath_save.mkdir(parents=True)

    # NOTE(review): thresholds run from 1 to len(adj_matrices) - 1. Whether the upper
    # bound should include len(adj_matrices) depends on how aggregate_partial_dags
    # compares vote counts against the threshold. Confirm before changing.
    for majority_threshold in range(1, len(adj_matrices)):
        aggregate_and_save(
            adj_matrices=adj_matrices,
            majority_threshold=majority_threshold,
            node_names=node_names,
            dirpath_save=dirpath_save,
        )


if __name__ == "__main__":
    main()