Skip to content

Commit f016850

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 56e4829 commit f016850

File tree

7 files changed

+13
-6
lines changed

7 files changed

+13
-6
lines changed

src/create_train_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33

44
import polars as pl
5+
56
from utils.functions import load_pickle
67

78
parser = argparse.ArgumentParser(description="Create train/val/test split.")

src/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from sklearn.model_selection import train_test_split
66
from torch.utils.data import Dataset
7+
78
from utils.functions import load_pickle, preview_data
89

910

src/evaluate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import polars as pl
88
import shap
99
import toml
10-
from datasets import CollateTimeSeries, MIMIC4Dataset
1110
from fairlearn.metrics import (
1211
MetricFrame,
1312
count,
@@ -19,7 +18,6 @@
1918
selection_rate,
2019
)
2120
from lightning.pytorch import Trainer
22-
from models import MMModel
2321
from sklearn.metrics import (
2422
accuracy_score,
2523
average_precision_score,
@@ -28,6 +26,9 @@
2826
)
2927
from torch import concat
3028
from torch.utils.data import DataLoader
29+
30+
from datasets import CollateTimeSeries, MIMIC4Dataset
31+
from models import MMModel
3132
from utils.functions import load_pickle, read_from_txt
3233
from utils.preprocessing import transform_race
3334

src/postprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import numpy as np
66
import polars as pl
77
import toml
8-
from datasets import MIMIC4Dataset
98
from fairlearn.postprocessing import ThresholdOptimizer, plot_threshold_optimizer
109
from sklearn.metrics import (
1110
accuracy_score,
1211
balanced_accuracy_score,
1312
confusion_matrix,
1413
)
14+
15+
from datasets import MIMIC4Dataset
1516
from utils.functions import load_pickle, read_from_txt
1617

1718
if __name__ == "__main__":

src/prepare_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import polars as pl
88
from tqdm import tqdm
9+
910
from utils.functions import scale_numeric_features
1011
from utils.preprocessing import (
1112
add_time_elapsed_to_events,

src/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22

33
import lightning as L
44
import toml
5-
from datasets import CollateFn, CollateTimeSeries, MIMIC4Dataset
65
from lightning.pytorch.callbacks import (
76
EarlyStopping,
87
LearningRateMonitor,
98
ModelCheckpoint,
109
)
1110
from lightning.pytorch.loggers import CSVLogger, WandbLogger
12-
from models import MMModel
1311
from torch.utils.data import DataLoader
12+
13+
from datasets import CollateFn, CollateTimeSeries, MIMIC4Dataset
14+
from models import MMModel
1415
from utils.functions import read_from_txt
1516

1617
if __name__ == "__main__":

src/train_rf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import toml
7-
from datasets import MIMIC4Dataset
87
from sklearn.ensemble import RandomForestClassifier
98
from sklearn.metrics import (
109
accuracy_score,
@@ -13,6 +12,8 @@
1312
roc_auc_score,
1413
)
1514
from sklearn.model_selection import GridSearchCV
15+
16+
from datasets import MIMIC4Dataset
1617
from utils.functions import read_from_txt
1718

1819
if __name__ == "__main__":

0 commit comments

Comments
 (0)