Skip to content

Commit 2289ae4

Browse files
black
1 parent 94ce3c9 commit 2289ae4

File tree

2 files changed

+61
-61
lines changed

2 files changed

+61
-61
lines changed

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class GeneticSearchAlgorithm(SearchAlgorithm):
14-
""" Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
14+
"""Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
1515

1616
def __init__(self, delta=0.05, config: dict = None) -> None:
1717
super().__init__()
@@ -26,7 +26,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2626
}
2727

2828
def generate_fitness_functions(
29-
self, surrogate_models: list[CubicSplineRegressionEstimator]
29+
self, surrogate_models: list[CubicSplineRegressionEstimator]
3030
) -> list[SearchFitnessFunction]:
3131
fitness_functions = []
3232

@@ -56,7 +56,6 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
5656
solutions = []
5757

5858
for fitness_function in fitness_functions:
59-
6059
gene_types, gene_space = self.create_gene_types(fitness_function, specification)
6160

6261
ga = GA(
@@ -90,8 +89,9 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
9089
return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges
9190

9291
@staticmethod
93-
def create_gene_types(fitness_function: SearchFitnessFunction, specification: CausalSpecification) -> tuple[
94-
list, list]:
92+
def create_gene_types(
93+
fitness_function: SearchFitnessFunction, specification: CausalSpecification
94+
) -> tuple[list, list]:
9595
"""Generate the gene_types and gene_space for a given fitness function and specification
9696
:param fitness_function: Instance of SearchFitnessFunction
9797
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""

causal_testing/testing/estimators.py

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ class Estimator(ABC):
4040
"""
4141

4242
def __init__(
43-
# pylint: disable=too-many-arguments
44-
self,
45-
treatment: str,
46-
treatment_value: float,
47-
control_value: float,
48-
adjustment_set: set,
49-
outcome: str,
50-
df: pd.DataFrame = None,
51-
effect_modifiers: dict[str:Any] = None,
52-
alpha: float = 0.05,
43+
# pylint: disable=too-many-arguments
44+
self,
45+
treatment: str,
46+
treatment_value: float,
47+
control_value: float,
48+
adjustment_set: set,
49+
outcome: str,
50+
df: pd.DataFrame = None,
51+
effect_modifiers: dict[str:Any] = None,
52+
alpha: float = 0.05,
5353
):
5454
self.treatment = treatment
5555
self.treatment_value = treatment_value
@@ -90,16 +90,16 @@ class LogisticRegressionEstimator(Estimator):
9090
"""
9191

9292
def __init__(
93-
# pylint: disable=too-many-arguments
94-
self,
95-
treatment: str,
96-
treatment_value: float,
97-
control_value: float,
98-
adjustment_set: set,
99-
outcome: str,
100-
df: pd.DataFrame = None,
101-
effect_modifiers: dict[str:Any] = None,
102-
formula: str = None,
93+
# pylint: disable=too-many-arguments
94+
self,
95+
treatment: str,
96+
treatment_value: float,
97+
control_value: float,
98+
adjustment_set: set,
99+
outcome: str,
100+
df: pd.DataFrame = None,
101+
effect_modifiers: dict[str:Any] = None,
102+
formula: str = None,
103103
):
104104
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
105105

@@ -162,7 +162,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
162162
return model.predict(x)
163163

164164
def estimate_control_treatment(
165-
self, adjustment_config: dict = None, bootstrap_size: int = 100
165+
self, adjustment_config: dict = None, bootstrap_size: int = 100
166166
) -> tuple[pd.Series, pd.Series]:
167167
"""Estimate the outcomes under control and treatment.
168168
@@ -280,17 +280,17 @@ class LinearRegressionEstimator(Estimator):
280280
"""
281281

282282
def __init__(
283-
# pylint: disable=too-many-arguments
284-
self,
285-
treatment: str,
286-
treatment_value: float,
287-
control_value: float,
288-
adjustment_set: set,
289-
outcome: str,
290-
df: pd.DataFrame = None,
291-
effect_modifiers: dict[Variable:Any] = None,
292-
formula: str = None,
293-
alpha: float = 0.05,
283+
# pylint: disable=too-many-arguments
284+
self,
285+
treatment: str,
286+
treatment_value: float,
287+
control_value: float,
288+
adjustment_set: set,
289+
outcome: str,
290+
df: pd.DataFrame = None,
291+
effect_modifiers: dict[Variable:Any] = None,
292+
formula: str = None,
293+
alpha: float = 0.05,
294294
):
295295
super().__init__(
296296
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
@@ -445,19 +445,19 @@ class CubicSplineRegressionEstimator(LinearRegressionEstimator):
445445
"""
446446

447447
def __init__(
448-
# pylint: disable=too-many-arguments
449-
self,
450-
treatment: str,
451-
treatment_value: float,
452-
control_value: float,
453-
adjustment_set: set,
454-
outcome: str,
455-
basis: int,
456-
df: pd.DataFrame = None,
457-
effect_modifiers: dict[Variable:Any] = None,
458-
formula: str = None,
459-
alpha: float = 0.05,
460-
expected_relationship=None,
448+
# pylint: disable=too-many-arguments
449+
self,
450+
treatment: str,
451+
treatment_value: float,
452+
control_value: float,
453+
adjustment_set: set,
454+
outcome: str,
455+
basis: int,
456+
df: pd.DataFrame = None,
457+
effect_modifiers: dict[Variable:Any] = None,
458+
formula: str = None,
459+
alpha: float = 0.05,
460+
expected_relationship=None,
461461
):
462462
super().__init__(
463463
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
@@ -497,17 +497,17 @@ class InstrumentalVariableEstimator(Estimator):
497497
"""
498498

499499
def __init__(
500-
# pylint: disable=too-many-arguments
501-
self,
502-
treatment: str,
503-
treatment_value: float,
504-
control_value: float,
505-
adjustment_set: set,
506-
outcome: str,
507-
instrument: str,
508-
df: pd.DataFrame = None,
509-
intercept: int = 1,
510-
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
500+
# pylint: disable=too-many-arguments
501+
self,
502+
treatment: str,
503+
treatment_value: float,
504+
control_value: float,
505+
adjustment_set: set,
506+
outcome: str,
507+
instrument: str,
508+
df: pd.DataFrame = None,
509+
intercept: int = 1,
510+
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
511511
):
512512
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
513513
self.intercept = intercept

0 commit comments

Comments
 (0)