Skip to content

Commit 4e8b0c9

Browse files
committed
Extra coverage
1 parent 8f32552 commit 4e8b0c9

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ def generate_concrete_tests(
230230
pre_break = True
231231
break
232232
if issubclass(self.treatment_variable.datatype, Enum) and set(
233-
itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
233+
{
234+
(x, y)
235+
for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
236+
if x != y
237+
}
234238
).issubset(zip(control_values, treatment_values)):
235239
pre_break = True
236240
break

tests/generation_tests/test_abstract_test_case.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def test_generate_boolean_concrete_test_cases(self):
7676
abstract = AbstractCausalTestCase(
7777
scenario=scenario,
7878
intervention_constraints={
79-
And(scenario.treatment_variables[self.X5.name].z3 == True, scenario.variables[self.X5.name].z3 == False)
79+
scenario.treatment_variables[self.X5.name].z3 != scenario.variables[self.X5.name].z3
8080
},
8181
treatment_variable=self.X5,
8282
expected_causal_effect={self.Y: Positive()},
8383
effect_modifiers=None,
8484
)
8585
concrete_tests, runs = abstract.generate_concrete_tests(2)
86-
assert len(concrete_tests) == 1, "Expected 1 concrete test"
87-
assert len(runs) == 1, "Expected 1 run"
86+
assert len(concrete_tests) == 2, "Expected 2 concrete test"
87+
assert len(runs) == 2, "Expected 2 run"
8888

8989
def test_generate_enum_concrete_test_cases(self):
9090
scenario = Scenario({self.Car})
@@ -98,7 +98,7 @@ def test_generate_enum_concrete_test_cases(self):
9898
expected_causal_effect={self.Y: Positive()},
9999
effect_modifiers=None,
100100
)
101-
concrete_tests, runs = abstract.generate_concrete_tests(2)
101+
concrete_tests, runs = abstract.generate_concrete_tests(10)
102102
assert len(concrete_tests) == 2, "Expected 2 concrete tests"
103103
assert len(runs) == 2, "Expected 2 runs"
104104

0 commit comments

Comments
 (0)