diff --git a/causal_testing/main.py b/causal_testing/main.py index 34f85d65..594be0e6 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -140,7 +140,7 @@ def load_dag(self) -> CausalDAG: """ logger.info(f"Loading DAG from {self.paths.dag_path}") dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles) - logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges") + logger.info(f"DAG loaded with {len(dag.nodes)} nodes and {len(dag.edges)} edges") return dag def _read_dataframe(self, data_path): @@ -172,18 +172,18 @@ def create_variables(self) -> None: """ Create variable objects from DAG nodes based on their connectivity. """ - for node_name, node_data in self.dag.graph.nodes(data=True): + for node_name, node_data in self.dag.nodes(data=True): if node_name not in self.data.columns and not node_data.get("hidden", False): raise ValueError(f"Node {node_name} missing from data. Should it be marked as hidden?") dtype = self.data.dtypes.get(node_name) # If node has no incoming edges, it's an input - if self.dag.graph.in_degree(node_name) == 0: + if self.dag.in_degree(node_name) == 0: self.variables["inputs"][node_name] = Input(name=node_name, datatype=dtype) # Otherwise it's an output - if self.dag.graph.in_degree(node_name) > 0: + if self.dag.in_degree(node_name) > 0: self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype) def create_scenario_and_specification(self) -> None: diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 07fe42b6..42327051 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -4,93 +4,21 @@ import logging from itertools import combinations -from random import sample -from typing import Union +from typing import Union, Set, Generator import networkx as nx -import pydot from causal_testing.testing.base_test_case import BaseTestCase from .scenario import Scenario from .variable import Output + Node = Union[str, int] # Node type hint: A node is a string or an int logger = logging.getLogger(__name__) -def list_all_min_sep( - graph: nx.Graph, - treatment_node: Node, - outcome_node: Node, - treatment_node_set: set[Node], - outcome_node_set: set[Node], -): - """A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph. - - Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata, - 2013, p.5, ListMinSep procedure). - - :param graph: An undirected graph. - :param treatment_node: The node corresponding to the treatment variable we wish to separate from the output. - :param outcome_node: The node corresponding to the outcome variable we wish to separate from the input. - :param treatment_node_set: Set of treatment nodes. - :param outcome_node_set: Set of outcome nodes. - :return: A list of minimal-sized sets of variables which separate treatment and outcome in the undirected graph. - """ - # 1. Compute the close separator of the treatment set - close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set) - - # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs) - components_graph = graph.copy() - components_graph.remove_nodes_from(close_separator_set) - graph_components = nx.connected_components(components_graph) - - # 3. Find the connected component that contains the treatment node - treatment_connected_component_node_set = set() - for component in graph_components: - if treatment_node in component: - treatment_connected_component_node_set = component - - # 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set - if not treatment_connected_component_node_set.intersection(outcome_node_set): - # 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node - treatment_node_set = treatment_connected_component_node_set - - # 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves) - treatment_node_set_neighbours = ( - set.union(*[set(nx.neighbors(graph, node)) for node in treatment_node_set]) - treatment_node_set - ) - - # 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set - if treatment_node_set_neighbours.difference(outcome_node_set): - # 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set - node = set(sample(sorted(treatment_node_set_neighbours.difference(outcome_node_set)), 1)) - - # 7.2. Add this node to the treatment node set and recurse (left branch) - yield from list_all_min_sep( - graph, - treatment_node, - outcome_node, - treatment_node_set.union(node), - outcome_node_set, - ) - - # 7.3. Add this node to the outcome node set and recurse (right branch) - yield from list_all_min_sep( - graph, - treatment_node, - outcome_node, - treatment_node_set, - outcome_node_set.union(node), - ) - else: - # 8. If all neighbours of the treatments nodes are in the outcome node set, return the set of treatment - # node neighbours - yield treatment_node_set_neighbours - - def close_separator( graph: nx.Graph, treatment_node: Node, outcome_node: Node, treatment_node_set: set[Node] ) -> set[Node]: @@ -109,40 +37,103 @@ def close_separator( :param treatment_node_set: The set of variables containing the treatment node ({treatment_node}). :return: A treatment_node-outcome_node separator whose vertices are adjacent to those in treatments. """ - treatment_neighbours = set.union(*[set(nx.neighbors(graph, treatment)) for treatment in treatment_node_set]) - components_graph = graph.copy() - components_graph.remove_nodes_from(treatment_neighbours) + treatment_neighbours = {x for treatment in treatment_node_set for x in graph[treatment]} + components_graph = graph.subgraph(set(graph.nodes) - treatment_neighbours) graph_components = nx.connected_components(components_graph) for component in graph_components: if outcome_node in component: - neighbours_of_variables_in_component = set.union( - *[set(nx.neighbors(graph, variable)) for variable in component] - ) + neighbours_of_variables_in_component = {x for variable in component for x in graph[variable]} # For this algorithm, the neighbours of a node do not include the node itself neighbours_of_variables_in_component = neighbours_of_variables_in_component.difference(component) return neighbours_of_variables_in_component raise ValueError(f"No {treatment_node}-{outcome_node} separator in the graph.") +def list_all_min_sep( + graph: nx.Graph, + treatment_node: str, + outcome_node: str, + treatment_node_set: Set, + outcome_node_set: Set, +) -> Generator[Set, None, None]: + """A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph. + + Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata, + 2013, p.5, ListMinSep procedure). + + :param graph: An undirected graph. + :param treatment_node: The node corresponding to the treatment variable we wish to separate from the output. + :param outcome_node: The node corresponding to the outcome variable we wish to separate from the input. + :param treatment_node_set: Set of treatment nodes. + :param outcome_node_set: Set of outcome nodes. + :return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected + graph. + """ + # 1. Compute the close separator of the treatment set + close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set) + + # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs) + components_graph = graph.subgraph(set(graph.nodes) - close_separator_set) + + # 3. Find the component containing the treatment node + treatment_component = set() + for component in nx.connected_components(components_graph): + if treatment_node in component: + treatment_component = component + break + + # 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set + if treatment_component.intersection(outcome_node_set): + return + + # 5. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves) + neighbour_nodes = { + neighbour for node in treatment_component for neighbour in graph[node] if neighbour not in treatment_component + } + + # 6. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set + remaining = neighbour_nodes - outcome_node_set + if remaining: + # 6.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set + chosen = {remaining.pop()} + # 6.2. Add this node to the treatment node set and recurse (left branch) + yield from list_all_min_sep( + graph, + treatment_node, + outcome_node, + treatment_component.union(chosen), + outcome_node_set, + ) + # 6.3. Add this node to the outcome node set and recurse (right branch) + yield from list_all_min_sep( + graph, + treatment_node, + outcome_node, + treatment_component, + outcome_node_set.union(chosen), + ) + else: + # Step 7: All neighbours are in outcome set — we found a separator + yield neighbour_nodes + + class CausalDAG(nx.DiGraph): """A causal DAG is a directed acyclic graph in which nodes represent random variables and edges represent causality between a pair of random variables. We implement a CausalDAG as a networkx DiGraph with an additional check that ensures it is acyclic. A CausalDAG must be specified as a dot file. """ - def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr): + def __init__(self, file_path: str = None, ignore_cycles: bool = False, **attr): super().__init__(**attr) self.ignore_cycles = ignore_cycles - if dot_path: - with open(dot_path, "r", encoding="utf-8") as file: - dot_content = file.read().replace("\n", "") - # Previously, we used pydot_graph_from_file() to read in the dot_path directly, however, - # this method does not currently have a way of removing spurious nodes. - # Workaround: Read in the file using open(), remove new lines, and then create the pydot_graph. - pydot_graph = pydot.graph_from_dot_data(dot_content) - self.graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(pydot_graph[0])) - else: - self.graph = nx.DiGraph() + if file_path: + if file_path.endswith(".dot"): + graph = nx.DiGraph(nx.nx_pydot.read_dot(file_path)) + elif file_path.endswith(".xml"): + graph = nx.graphml.read_graphml(file_path) + else: + raise ValueError(f"Unsupported file extension {file_path}. We only support .dot and .xml files.") + self.update(graph) if not self.is_acyclic(): if ignore_cycles: @@ -152,22 +143,6 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr): else: raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") - @property - def nodes(self) -> list: - """ - Get the nodes of the DAG. - :returns: The nodes of the DAG. - """ - return self.graph.nodes - - @property - def edges(self) -> list: - """ - Get the edges of the DAG. - :returns: The edges of the DAG. - """ - return self.graph.edges - def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: """ Checks the three instrumental variable assumptions, raising a @@ -176,11 +151,11 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: :return Boolean True if the three IV assumptions hold. """ # (i) Instrument is associated with treatment - if nx.d_separated(self.graph, {instrument}, {treatment}, set()): + if nx.d_separated(self, {instrument}, {treatment}, set()): raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG") # (ii) Instrument does not affect outcome except through its potential effect on treatment - if not all((treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome))): + if not all((treatment in path for path in nx.all_simple_paths(self, source=instrument, target=outcome))): raise ValueError( f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}" ) @@ -189,11 +164,9 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: for cause in self.nodes: # Exclude self-cycles due to breaking changes in NetworkX > 3.2 - outcome_paths = ( - list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else [] - ) + outcome_paths = list(nx.all_simple_paths(self, source=cause, target=outcome)) if cause != outcome else [] instrument_paths = ( - list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else [] + list(nx.all_simple_paths(self, source=cause, target=instrument)) if cause != instrument else [] ) if len(instrument_paths) > 0 and len(outcome_paths) > 0: raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") @@ -207,7 +180,7 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): :param v_of_edge: To node :param attr: Attributes """ - self.graph.add_edge(u_of_edge, v_of_edge, **attr) + super().add_edge(u_of_edge, v_of_edge, **attr) if not self.is_acyclic(): raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") @@ -215,7 +188,7 @@ def cycle_nodes(self) -> list: """Get the nodes involved in any cycles. :return: A list containing all nodes involved in a cycle. """ - return [node for cycle in nx.simple_cycles(self.graph) for node in cycle] + return [node for cycle in nx.simple_cycles(self) for node in cycle] def is_acyclic(self) -> bool: """Checks if the graph is acyclic. @@ -242,12 +215,11 @@ def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) if var not in self.nodes: raise IndexError(f"{var} not a node in Causal DAG.\nValid nodes are{self.nodes}.") - proper_backdoor_graph = self.copy() - nodes_on_proper_causal_path = proper_backdoor_graph.proper_causal_pathway(treatments, outcomes) - edges_to_remove = [ - (u, v) for (u, v) in proper_backdoor_graph.graph.out_edges(treatments) if v in nodes_on_proper_causal_path - ] - proper_backdoor_graph.graph.remove_edges_from(edges_to_remove) + nodes_on_proper_causal_path = self.proper_causal_pathway(treatments, outcomes) + proper_backdoor_graph = CausalDAG() + edges_to_remove = {(u, v) for (u, v) in self.out_edges(treatments) if v in nodes_on_proper_causal_path} + proper_backdoor_graph.add_nodes_from(self.nodes) + proper_backdoor_graph.add_edges_from(e for e in self.edges if e not in edges_to_remove) return proper_backdoor_graph def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG: @@ -264,17 +236,9 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus :param outcomes: A list of outcome variables to include in the ancestral graph (and their ancestors). :return: An ancestral graph relative to the set of variables X union Y. """ - ancestor_graph = self.copy() - treatment_ancestors = set.union( - *[nx.ancestors(ancestor_graph.graph, treatment).union({treatment}) for treatment in treatments] - ) - outcome_ancestors = set.union( - *[nx.ancestors(ancestor_graph.graph, outcome).union({outcome}) for outcome in outcomes] - ) - variables_to_keep = treatment_ancestors.union(outcome_ancestors) - variables_to_remove = set(self.nodes).difference(variables_to_keep) - ancestor_graph.graph.remove_nodes_from(variables_to_remove) - return ancestor_graph + variables_to_keep = set(treatments + outcomes) + variables_to_keep.update({ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var)}) + return self.subgraph(variables_to_keep) def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG: """ @@ -286,14 +250,10 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus :return: The indirect graph with edges pointing from X to Y removed. :rtype: CausalDAG """ - gback = self.copy() - ee = [] - for s in treatments: - for t in outcomes: - if (s, t) in gback.edges: - ee.append((s, t)) - for v1, v2 in ee: - gback.graph.remove_edge(v1, v2) + ee = {(s, t) for s in treatments for t in outcomes if (s, t) in self.edges} + gback = CausalDAG() + gback.add_nodes_from(self.nodes) + gback.add_edges_from(filter(lambda x: x not in ee, self.edges)) return gback def direct_effect_adjustment_sets( @@ -322,7 +282,7 @@ def direct_effect_adjustment_sets( indirect_graph = self.get_indirect_graph(treatments, outcomes) ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes) - gam = nx.moral_graph(ancestor_graph.graph) + gam = nx.moral_graph(ancestor_graph) edges_to_add = [("TREATMENT", treatment) for treatment in treatments] edges_to_add += [("OUTCOME", outcome) for outcome in outcomes] @@ -333,7 +293,7 @@ def direct_effect_adjustment_sets( min_seps.remove(set(outcomes)) return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps))) - def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: + def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> Generator[set[str]]: """Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments and outcomes. @@ -354,46 +314,38 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis :return: A list of strings representing the minimal adjustment set. """ - # 1. Construct the proper back-door graph's ancestor moral graph + # Step 1: Build the proper back-door graph and its moralized ancestor graph proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes) ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes) - moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph.graph) - - # 2. Add an edge X^m to treatment nodes and Y^m to outcome nodes - edges_to_add = [("TREATMENT", treatment) for treatment in treatments] - edges_to_add += [("OUTCOME", outcome) for outcome in outcomes] - moralised_proper_backdoor_graph.add_edges_from(edges_to_add) - - # 3. Remove treatment and outcome nodes from graph and connect neighbours - treatment_neighbours = set.union( - *[set(nx.neighbors(moralised_proper_backdoor_graph, treatment)) for treatment in treatments] - ) - set(treatments) - outcome_neighbours = set.union( - *[set(nx.neighbors(moralised_proper_backdoor_graph, outcome)) for outcome in outcomes] - ) - set(outcomes) - - neighbour_edges_to_add = list(combinations(treatment_neighbours, 2)) + list(combinations(outcome_neighbours, 2)) - moralised_proper_backdoor_graph.add_edges_from(neighbour_edges_to_add) - - # 4. Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators - treatment_node_set = {"TREATMENT"} - outcome_node_set = set(nx.neighbors(moralised_proper_backdoor_graph, "OUTCOME")).union({"OUTCOME"}) - minimum_adjustment_sets = list( - list_all_min_sep( - moralised_proper_backdoor_graph, - "TREATMENT", - "OUTCOME", - treatment_node_set, - outcome_node_set, - ) + moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph) + + # Step 2: Add artificial TREATMENT and OUTCOME nodes + moralised_proper_backdoor_graph.add_edges_from([("TREATMENT", t) for t in treatments]) + moralised_proper_backdoor_graph.add_edges_from([("OUTCOME", y) for y in outcomes]) + + # Step 3: Remove treatment and outcome nodes from graph and connect neighbours + treatment_neighbors = { + node for t in treatments for node in moralised_proper_backdoor_graph[t] if node not in treatments + } + moralised_proper_backdoor_graph.add_edges_from(combinations(treatment_neighbors, 2)) + + outcome_neighbors = { + node for o in outcomes for node in moralised_proper_backdoor_graph[o] if node not in outcomes + } + moralised_proper_backdoor_graph.add_edges_from(combinations(outcome_neighbors, 2)) + + # Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators + sep_candidates = list_all_min_sep( + moralised_proper_backdoor_graph, + "TREATMENT", + "OUTCOME", + {"TREATMENT"}, + set(moralised_proper_backdoor_graph["OUTCOME"]) | {"OUTCOME"}, + ) + return filter( + lambda s: self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s), + sep_candidates, ) - valid_minimum_adjustment_sets = [ - adj - for adj in minimum_adjustment_sets - if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj) - ] - - return valid_minimum_adjustment_sets def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool: """Given a list of treatments X, a list of outcomes Y, and an adjustment set Z, determine whether Z is the @@ -418,8 +370,7 @@ def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], # Remove each variable one at a time and return false if constructive back-door criterion remains satisfied for variable in adjustment_set: - smaller_adjustment_set = adjustment_set.copy() - smaller_adjustment_set.remove(variable) + smaller_adjustment_set = {a for a in adjustment_set if a != variable} if not smaller_adjustment_set: # Treat None as the empty set smaller_adjustment_set = set() if self.constructive_backdoor_criterion( @@ -434,7 +385,11 @@ def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], return True def constructive_backdoor_criterion( - self, proper_backdoor_graph: CausalDAG, treatments: list[str], outcomes: list[str], covariates: list[str] + self, + proper_backdoor_graph: CausalDAG, + treatments: list[str], + outcomes: list[str], + covariates: list[str], ) -> bool: """A variation of Pearl's back-door criterion applied to a proper backdoor graph which enables more efficient computation of minimal adjustment sets for the effect of a set of treatments on a set of outcomes. @@ -455,34 +410,31 @@ def constructive_backdoor_criterion( :return: True or False, depending on whether the set of covariates satisfies the constructive back-door criterion. """ - # Condition (1) - proper_causal_path_vars = self.proper_causal_pathway(treatments, outcomes) - if proper_causal_path_vars: - descendents_of_proper_casual_paths = set.union( - *[ - set.union( - nx.descendants(self.graph, proper_causal_path_var), - {proper_causal_path_var}, - ) - for proper_causal_path_var in proper_causal_path_vars - ] + + # Condition (1): Covariates must not be descendants of any node on a proper causal path + proper_path_vars = self.proper_causal_pathway(treatments, outcomes) + if proper_path_vars: + # Collect all descendants including each proper causal path var itself + descendents_of_proper_casual_paths = set(proper_path_vars) + descendents_of_proper_casual_paths.update( + {node for var in proper_path_vars for node in nx.descendants(self, var)} ) if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)): + # Covariates intersect with disallowed descendants — fail condition 1 logger.info( - "Failed Condition 1: Z=%s **is** a descendent of some variable on a proper causal " - "path between X=%s and Y=%s.", + "Failed Condition 1: " + "Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.", covariates, treatments, outcomes, ) return False - # Condition (2) - if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)): + # Condition (2): Z must d-separate X and Y in the proper back-door graph + if not nx.d_separated(proper_backdoor_graph, set(treatments), set(outcomes), set(covariates)): logger.info( - "Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in" - " the proper back-door graph relative to X and Y.", + "Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.", covariates, treatments, outcomes, @@ -503,12 +455,12 @@ def proper_causal_pathway(self, treatments: list[str], outcomes: list[str]) -> l :return vars_on_proper_causal_pathway: Return a list of the variables on the proper causal pathway between treatments and outcomes. """ - treatments_descendants = set.union( - *[nx.descendants(self.graph, treatment).union({treatment}) for treatment in treatments] - ) - treatments_descendants_without_treatments = set(treatments_descendants).difference(treatments) + treatments_descendants_without_treatments = { + x for treatment in treatments for x in nx.descendants(self, treatment) if x not in treatments + } backdoor_graph = self.get_backdoor_graph(set(treatments)) - outcome_ancestors = set.union(*[nx.ancestors(backdoor_graph, outcome).union({outcome}) for outcome in outcomes]) + outcome_ancestors = set(outcomes) + outcome_ancestors.update({x for outcome in outcomes for x in nx.ancestors(backdoor_graph, outcome)}) nodes_on_proper_causal_paths = treatments_descendants_without_treatments.intersection(outcome_ancestors) return nodes_on_proper_causal_paths @@ -519,9 +471,10 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG: :param treatments: The set of treatments whose outgoing edges will be deleted. :return: A back-door graph corresponding to the given causal DAG and set of treatments. """ - outgoing_edges = self.graph.out_edges(treatments) - backdoor_graph = self.graph.copy() - backdoor_graph.remove_edges_from(outgoing_edges) + outgoing_edges = self.out_edges(treatments) + backdoor_graph = CausalDAG() + backdoor_graph.add_nodes_from(self.nodes) + backdoor_graph.add_edges_from(filter(lambda x: x not in outgoing_edges, self.edges)) return backdoor_graph def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool: @@ -538,7 +491,7 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool: """ if isinstance(scenario.variables[node], Output): return True - return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node))) + return any((self.depends_on_outputs(n, scenario) for n in self.predecessors(node))) @staticmethod def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario): @@ -557,8 +510,9 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None :return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal estimate as opposed to a purely associational estimate. """ + # Naive method to guarantee termination when we have cycles if self.ignore_cycles: - return set(self.graph.predecessors(base_test_case.treatment_variable.name)) + return set(self.predecessors(base_test_case.treatment_variable.name)) minimal_adjustment_sets = [] if base_test_case.effect == "total": minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets( @@ -574,10 +528,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None if scenario is not None: minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario) - if len(minimal_adjustment_sets) == 0: - return set() - - minimal_adjustment_set = min(minimal_adjustment_sets, key=len) + minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=set()) return set(minimal_adjustment_set) def to_dot_string(self) -> str: diff --git a/causal_testing/surrogate/causal_surrogate_assisted.py b/causal_testing/surrogate/causal_surrogate_assisted.py index a7d436a4..8a97d2cf 100644 --- a/causal_testing/surrogate/causal_surrogate_assisted.py +++ b/causal_testing/surrogate/causal_surrogate_assisted.py @@ -125,7 +125,7 @@ def generate_surrogates( surrogate_models = [] for u, v in specification.causal_dag.edges: - edge_metadata = specification.causal_dag.graph.adj[u][v] + edge_metadata = specification.causal_dag.adj[u][v] if "included" in edge_metadata: from_var = specification.scenario.variables.get(u) to_var = specification.scenario.variables.get(v) diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 55b4381f..2bfa9d12 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -133,13 +133,13 @@ def generate_metamorphic_relation( # Create a ShouldNotCause relation for each pair of nodes that are not directly connected if ((u, v) not in dag.edges) and ((v, u) not in dag.edges): # Case 1: U --> ... --> V - if u in nx.ancestors(dag.graph, v): + if u in nx.ancestors(dag, v): adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) if adj_sets: metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0]))) # Case 2: V --> ... --> U - elif v in nx.ancestors(dag.graph, u): + elif v in nx.ancestors(dag, u): adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore) if adj_sets: metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0]))) @@ -221,7 +221,7 @@ def generate_causal_tests( causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles) dag_nodes_to_test = [ - node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node] + node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag, "test", default=True)[node] ] if not causal_dag.is_acyclic() and ignore_cycles: @@ -241,7 +241,7 @@ def generate_causal_tests( tests = [ relation.to_json_stub(**json_stub_kargs) for relation in relations - if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0 + if len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0 ] logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.") diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py index 36ce4709..3e93f88c 100644 --- a/tests/main_tests/test_main.py +++ b/tests/main_tests/test_main.py @@ -93,7 +93,7 @@ def test_load_data_query(self): def test_load_dag_missing_node(self): framework = CausalTestingFramework(self.paths) framework.setup() - framework.dag.graph.add_node("missing") + framework.dag.add_node("missing") with self.assertRaises(ValueError): framework.create_variables() diff --git a/tests/resources/data/dag.xml b/tests/resources/data/dag.xml new file mode 100644 index 00000000..fb577278 --- /dev/null +++ b/tests/resources/data/dag.xml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + diff --git a/tests/specification_tests/test_causal_dag.py b/tests/specification_tests/test_causal_dag.py index a619122d..fb808d44 100644 --- a/tests/specification_tests/test_causal_dag.py +++ b/tests/specification_tests/test_causal_dag.py @@ -20,12 +20,18 @@ def setUp(self) -> None: with open(self.dag_dot_path, "w") as f: f.write(dag_dot) + def test_graphml(self): + dot_dag = CausalDAG(self.dag_dot_path) + xml_dag = CausalDAG(os.path.join("tests", "resources", "data", "dag.xml")) + self.assertEqual(dot_dag.nodes, xml_dag.nodes) + self.assertEqual(dot_dag.edges, xml_dag.edges) + def test_enumerate_minimal_adjustment_sets(self): """Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets.""" causal_dag = CausalDAG(self.dag_dot_path) xs, ys = ["X"], ["Y"] adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) - self.assertEqual([{"Z"}], adjustment_sets) + self.assertEqual([{"Z"}], list(adjustment_sets)) def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) @@ -46,19 +52,19 @@ def test_valid_iv(self): def test_unrelated_instrument(self): causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.remove_edge("I", "X") + causal_dag.remove_edge("I", "X") with self.assertRaises(ValueError): causal_dag.check_iv_assumptions("X", "Y", "I") def test_direct_cause(self): causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.add_edge("I", "Y") + causal_dag.add_edge("I", "Y") with self.assertRaises(ValueError): causal_dag.check_iv_assumptions("X", "Y", "I") def test_common_cause(self): causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.add_edge("U", "I") + causal_dag.add_edge("U", "I") with self.assertRaises(ValueError): causal_dag.check_iv_assumptions("X", "Y", "I") @@ -279,12 +285,12 @@ def test_enumerate_minimal_adjustment_sets(self): causal_dag = CausalDAG(self.dag_dot_path) xs, ys = ["X1", "X2"], ["Y"] adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) - self.assertEqual([{"Z"}], adjustment_sets) + self.assertEqual([{"Z"}], list(adjustment_sets)) def test_enumerate_minimal_adjustment_sets_multiple(self): """Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible.""" causal_dag = CausalDAG() - causal_dag.graph.add_edges_from( + causal_dag.add_edges_from( [ ("X1", "X2"), ("X2", "V"), @@ -308,7 +314,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self): def test_enumerate_minimal_adjustment_sets_two_adjustments(self): """Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two.""" causal_dag = CausalDAG() - causal_dag.graph.add_edges_from( + causal_dag.add_edges_from( [ ("X1", "X2"), ("X2", "V"), @@ -335,7 +341,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self): def test_dag_with_non_character_nodes(self): """Test identification for a DAG whose nodes are not just characters (strings of length greater than 1).""" causal_dag = CausalDAG() - causal_dag.graph.add_edges_from( + causal_dag.add_edges_from( [ ("va", "ba"), ("ba", "ia"), @@ -350,7 +356,7 @@ def test_dag_with_non_character_nodes(self): ) xs, ys = ["ba"], ["da"] adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) - self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}]) + self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}]) def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) @@ -475,3 +481,12 @@ def test_hidden_varaible_adjustment_sets(self): def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) + + +def time_it(label, func, *args, **kwargs): + import time + + start = time.time() + result = func(*args, **kwargs) + print(f"{label} took {time.time() - start:.6f} seconds") + return result diff --git a/tests/testing_tests/test_metamorphic_relations.py b/tests/testing_tests/test_metamorphic_relations.py index 6d838126..a57c0c23 100644 --- a/tests/testing_tests/test_metamorphic_relations.py +++ b/tests/testing_tests/test_metamorphic_relations.py @@ -48,7 +48,7 @@ def test_should_not_cause_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.remove_nodes_from(["X2", "X3"]) + causal_dag.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( @@ -70,7 +70,7 @@ def test_should_not_cause_logistic_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.remove_nodes_from(["X2", "X3"]) + causal_dag.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( @@ -94,7 +94,7 @@ def test_should_cause_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.remove_nodes_from(["X2", "X3"]) + causal_dag.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( @@ -115,7 +115,7 @@ def test_should_cause_logistic_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" causal_dag = CausalDAG(self.dag_dot_path) - causal_dag.graph.remove_nodes_from(["X2", "X3"]) + causal_dag.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( @@ -265,8 +265,7 @@ def test_generate_causal_tests_ignore_cycles(self): map( lambda x: x.to_json_stub(skip=True), filter( - lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable))) - > 0, + lambda relation: len(list(dcg.predecessors(relation.base_test_case.outcome_variable))) > 0, relations, ), ) @@ -285,8 +284,7 @@ def test_generate_causal_tests(self): map( lambda x: x.to_json_stub(skip=True), filter( - lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable))) - > 0, + lambda relation: len(list(dag.predecessors(relation.base_test_case.outcome_variable))) > 0, relations, ), )