Skip to content

Commit 6f9a5a9

Browse files
committed
save the test predictions for the base line models
1 parent ddc3816 commit 6f9a5a9

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/cehrbert/evaluations/model_evaluators/frequency_model_evaluators.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def eval_model(self):
4040
self._model = self._model.fit(x, y)
4141
else:
4242
self._model.fit(x, y)
43-
compute_binary_metrics(self._model, test_data, self.get_model_metrics_folder())
43+
44+
compute_binary_metrics(
45+
self._model,
46+
test_data,
47+
self.get_model_metrics_folder(),
48+
evaluation_model_folder=self.get_model_test_prediction_folder(),
49+
)
4450
else:
4551
for train, test in self.k_fold(features=(inputs, age, person_ids), labels=labels):
4652
x, y = train
@@ -50,7 +56,12 @@ def eval_model(self):
5056
else:
5157
self._model.fit(x, y)
5258

53-
compute_binary_metrics(self._model, test, self.get_model_metrics_folder())
59+
compute_binary_metrics(
60+
self._model,
61+
test,
62+
self.get_model_metrics_folder(),
63+
evaluation_model_folder=self.get_model_test_prediction_folder(),
64+
)
5465

5566
def get_model_name(self):
5667
return type(self._model).__name__
@@ -130,7 +141,7 @@ def _create_model(self, *args, **kwargs):
130141
param_grid = [
131142
{
132143
"classifier": [LogisticRegression()],
133-
"classifier__penalty": ["l1", "l2"],
144+
"classifier__penalty": ["l2"],
134145
"classifier__C": np.logspace(-4, 4, 20),
135146
"classifier__solver": ["lbfgs"],
136147
"classifier__max_iter": [2000],

src/cehrbert/trainers/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_model_test_metrics_folder(self):
4040
return create_folder_if_not_exist(self.get_model_folder(), "test_metrics")
4141

4242
def get_model_test_prediction_folder(self):
43-
return create_folder_if_not_exist(self.get_model_folder(), "test_prediction")
43+
return create_folder_if_not_exist(self.get_model_folder(), "test_predictions")
4444

4545
def get_model_history_folder(self):
4646
return create_folder_if_not_exist(self.get_model_folder(), "history")

0 commit comments

Comments
 (0)