-
Notifications
You must be signed in to change notification settings - Fork 5
Temporal shenanigans #286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Temporal shenanigans #286
Changes from all commits
4951029
b321170
0d76e92
e2aef75
392387c
cff1776
bffe827
fbfe76a
6ceef1a
c20899d
e985641
41ea949
c1c64c2
a619069
4abe735
2b76d3c
e6284c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
This module contains the Capability and TreatmentSequence classes to implement | ||
treatment sequences that operate over time. | ||
""" | ||
from typing import Any | ||
from causal_testing.specification.variable import Variable | ||
|
||
|
||
class Capability:
    """
    Data class encapsulating a single temporal intervention: a variable set to
    a value over a half-open window of simulation time.
    """

    def __init__(self, variable: Variable, value: Any, start_time: int, end_time: int):
        self.variable = variable
        self.value = value
        self.start_time = start_time
        self.end_time = end_time

    def __eq__(self, other):
        # Two capabilities are equal iff they are the same kind of object and
        # agree on every field.
        if not isinstance(other, type(self)):
            return False
        return (self.variable, self.value, self.start_time, self.end_time) == (
            other.variable,
            other.value,
            other.start_time,
            other.end_time,
        )

    def __repr__(self):
        return f"({self.variable}, {self.value}, {self.start_time}-{self.end_time})"
|
||
|
||
class TreatmentSequence:
    """
    Class to represent a list of capabilities, i.e. a treatment regime.
    """

    def __init__(self, timesteps_per_intervention, capabilities):
        self.timesteps_per_intervention = timesteps_per_intervention
        # The i-th intervention (0-based) starts at (i + 1) * dt and runs for
        # one interval, so the first intervention begins after a grace period
        # of one interval rather than at t = 0.
        self.capabilities = [
            Capability(
                var,
                val,
                (index + 1) * timesteps_per_intervention,
                (index + 2) * timesteps_per_intervention,
            )
            for index, (var, val) in enumerate(capabilities)
        ]
        # This is a bodge so that causal test adequacy works
        self.name = tuple(capability.variable for capability in self.capabilities)

    def set_value(self, index: int, value: float):
        """
        Set the value of the capability at the given index.
        :param index: the index of the element to update.
        :param value: the desired value of the capability.
        """
        self.capabilities[index].value = value

    def copy(self):
        """
        Return a deep copy of the capability list.
        """
        variable_value_pairs = [(capability.variable, capability.value) for capability in self.capabilities]
        return TreatmentSequence(self.timesteps_per_intervention, variable_value_pairs)

    def total_time(self):
        """
        Calculate the total duration of the treatment strategy.
        """
        # The extra interval accounts for the grace period before the first
        # intervention plus the observation window after the last one.
        return (len(self.capabilities) + 1) * self.timesteps_per_intervention
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,20 @@ | |
This module contains code to measure various aspects of causal test adequacy. | ||
""" | ||
|
||
import logging | ||
from itertools import combinations | ||
from copy import deepcopy | ||
import pandas as pd | ||
from numpy.linalg import LinAlgError | ||
from lifelines.exceptions import ConvergenceError | ||
|
||
from causal_testing.testing.causal_test_suite import CausalTestSuite | ||
from causal_testing.data_collection.data_collector import DataCollector | ||
from causal_testing.specification.causal_dag import CausalDAG | ||
from causal_testing.testing.estimators import Estimator | ||
from causal_testing.testing.causal_test_case import CausalTestCase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DAGAdequacy: | ||
""" | ||
|
@@ -70,15 +74,21 @@ | |
- Zero kurtosis is optimal. | ||
""" | ||
|
||
# pylint: disable=too-many-instance-attributes | ||
def __init__( | ||
self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100 | ||
self, | ||
test_case: CausalTestCase, | ||
estimator: Estimator, | ||
bootstrap_size: int = 100, | ||
group_by=None, | ||
): | ||
self.test_case = test_case | ||
self.estimator = estimator | ||
self.data_collector = data_collector | ||
self.kurtosis = None | ||
self.outcomes = None | ||
self.successful = None | ||
self.bootstrap_size = bootstrap_size | ||
self.group_by = group_by | ||
|
||
def measure_adequacy(self): | ||
""" | ||
|
@@ -87,11 +97,24 @@ | |
results = [] | ||
for i in range(self.bootstrap_size): | ||
estimator = deepcopy(self.estimator) | ||
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i) | ||
# try: | ||
results.append(self.test_case.execute_test(estimator, self.data_collector)) | ||
# except np.LinAlgError: | ||
# continue | ||
|
||
if self.group_by is not None: | ||
ids = pd.Series(estimator.df[self.group_by].unique()) | ||
ids = ids.sample(len(ids), replace=True, random_state=i) | ||
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)] | ||
else: | ||
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i) | ||
try: | ||
results.append(self.test_case.execute_test(estimator, None)) | ||
except LinAlgError: | ||
logger.warning("Adequacy LinAlgError") | ||
continue | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would logging any warnings here be helpful? Is it okay if it fails to append results? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added in the warnings, but it's very OK to continue. What's happening here is that we're resampling the data and recalculating the estimate. If you don't have a lot of data, some of those samples may fail to converge, or fail in other ways to produce an estimate. This indicates that you don't have enough data, so I've now added "number of successes" as a third dimension to causal test adequacy. |
||
except ConvergenceError: | ||
logger.warning("Adequacy ConvergenceError") | ||
continue | ||
except ValueError as e: | ||
logger.warning(f"Adequacy ValueError: {e}") | ||
continue | ||
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results] | ||
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]] | ||
|
||
|
@@ -111,8 +134,14 @@ | |
|
||
effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True) | ||
self.kurtosis = effect_estimate.kurtosis() | ||
self.outcomes = sum(outcomes) | ||
self.outcomes = sum(filter(lambda x: x is not None, outcomes)) | ||
self.successful = sum(x is not None for x in outcomes) | ||
|
||
def to_dict(self): | ||
"Returns the adequacy object as a dictionary." | ||
return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes} | ||
return { | ||
"kurtosis": self.kurtosis.to_dict(), | ||
"bootstrap_size": self.bootstrap_size, | ||
"passing": self.outcomes, | ||
"successful": self.successful, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jmafoster1 Out of curiosity, why is the total time defined as
(n+1)*dt
instead of n*dt
? Do you assume the treatment strategy starts at t=dt
instead of t=0
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@f-allian, yes exactly. This is slightly geared towards the water case study, where there is a grace period of dt before the first intervention takes place, and another period of dt after the last intervention before we can observe the final outcome. A slightly more general solution might be to have an extra
grace_period
parameter and then have the total time be len(capabilities) * timesteps_per_intervention + grace_period
, but I'll keep it as it is for now.