Skip to content

Commit 36d8139

Browse files
committed
check fit for aggregate
1 parent f7a6ffd commit 36d8139

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

doubleml/did/did_multi.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -778,14 +778,15 @@ def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
778778

779779
def aggregate(self, aggregation="simple"):
780780
if not isinstance(aggregation, str):
781-
raise TypeError(
782-
"aggregation must be a string. " f"{str(aggregation)} of type {type(aggregation)} was passed."
783-
)
781+
raise TypeError("aggregation must be a string. " f"{str(aggregation)} of type {type(aggregation)} was passed.")
784782
valid_aggregations = ["simple"]
785783
if aggregation not in valid_aggregations:
786-
raise ValueError(
787-
f"aggregation must be one of {valid_aggregations}. " f"{str(aggregation)} was passed."
788-
)
784+
raise ValueError(f"aggregation must be one of {valid_aggregations}. " f"{str(aggregation)} was passed.")
785+
if self.framework is None:
786+
raise ValueError("Apply fit() before aggregate().")
787+
788+
if aggregation == "simple":
789+
pass
789790
pass
790791

791792
def _fit_model(self, i_gt, n_jobs_cv=None, store_predictions=True, store_models=False, external_predictions_dict=None):
@@ -888,9 +889,7 @@ def _rename_external_predictions(self, external_predictions):
888889
return ext_pred_dict
889890

890891
def _calc_nuisance_loss(self):
891-
nuisance_loss = {
892-
learner: np.full((self.n_rep, self.n_gt_atts), np.nan) for learner in self.modellist[0].params_names
893-
}
892+
nuisance_loss = {learner: np.full((self.n_rep, self.n_gt_atts), np.nan) for learner in self.modellist[0].params_names}
894893
for i_model, model in enumerate(self.modellist):
895894
for learner in self.modellist[0].params_names:
896895
for i_rep in range(self.n_rep):

doubleml/did/tests/test_did_multi_exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,8 @@ def test_exceptions_aggregate():
122122
msg = "aggregation must be one of \\['simple'\\]. invalid was passed."
123123
with pytest.raises(ValueError, match=msg):
124124
dml_obj.aggregate(aggregation="invalid")
125+
126+
# test without fit()
127+
msg = r"Apply fit\(\) before aggregate\(\)."
128+
with pytest.raises(ValueError, match=msg):
129+
dml_obj.aggregate()

0 commit comments

Comments
 (0)