Skip to content

Commit 16c468d

Browse files
committed
pylint
1 parent 3dc9ef3 commit 16c468d

File tree

1 file changed

+71
-83
lines changed

1 file changed

+71
-83
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 71 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
import logging
66
from itertools import combinations
77
from random import sample
8-
from typing import Union, Generator
8+
from typing import Union, Generator, Set
99

1010
import networkx as nx
1111
import pydot
1212

13-
from typing import Generator, Set
14-
1513
from causal_testing.testing.base_test_case import BaseTestCase
1614

1715
from .scenario import Scenario
@@ -136,13 +134,7 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
136134
super().__init__(**attr)
137135
self.ignore_cycles = ignore_cycles
138136
if dot_path:
139-
with open(dot_path, "r", encoding="utf-8") as file:
140-
dot_content = file.read().replace("\n", "")
141-
# Previously, we used pydot_graph_from_file() to read in the dot_path directly, however,
142-
# this method does not currently have a way of removing spurious nodes.
143-
# Workaround: Read in the file using open(), remove new lines, and then create the pydot_graph.
144-
pydot_graph = pydot.graph_from_dot_data(dot_content)
145-
self.graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(pydot_graph[0]))
137+
self.graph = nx.DiGraph(nx.nx_pydot.read_dot(dot_path))
146138
else:
147139
self.graph = nx.DiGraph()
148140

@@ -576,10 +568,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
576568
if scenario is not None:
577569
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
578570

579-
if len(minimal_adjustment_sets) == 0:
580-
return set()
581-
582-
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
571+
minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=set())
583572
return set(minimal_adjustment_set)
584573

585574
def to_dot_string(self) -> str:
@@ -639,7 +628,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
639628
moralised_proper_backdoor_graph.add_edges_from(combinations(outcome_neighbors, 2))
640629

641630
# Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
642-
sep_candidates = list_all_min_sep_opt(
631+
sep_candidates = self.list_all_min_sep_opt(
643632
moralised_proper_backdoor_graph,
644633
"TREATMENT",
645634
"OUTCOME",
@@ -689,7 +678,8 @@ def constructive_backdoor_criterion(
689678
if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)):
690679
# Covariates intersect with disallowed descendants — fail condition 1
691680
logger.info(
692-
"Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
681+
"Failed Condition 1: "
682+
"Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
693683
covariates,
694684
treatments,
695685
outcomes,
@@ -708,76 +698,74 @@ def constructive_backdoor_criterion(
708698

709699
return True
710700

701+
def list_all_min_sep_opt(
702+
self,
703+
graph: nx.Graph,
704+
treatment_node: str,
705+
outcome_node: str,
706+
treatment_node_set: Set,
707+
outcome_node_set: Set,
708+
) -> Generator[Set, None, None]:
709+
"""A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
710+
711+
Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
712+
2013, p.5, ListMinSep procedure).
713+
714+
:param graph: An undirected graph.
715+
:param treatment_node: The node corresponding to the treatment variable we wish to separate from the outcome.
716+
:param outcome_node: The node corresponding to the outcome variable we wish to separate from the treatment.
717+
:param treatment_node_set: Set of treatment nodes.
718+
:param outcome_node_set: Set of outcome nodes.
719+
:return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected
720+
graph.
721+
"""
722+
# 1. Compute the close separator of the treatment set
723+
close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set)
711724

712-
def list_all_min_sep_opt(
713-
graph: nx.Graph,
714-
treatment_node,
715-
outcome_node,
716-
treatment_node_set: Set,
717-
outcome_node_set: Set,
718-
) -> Generator[Set, None, None]:
719-
"""A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
720-
721-
Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
722-
2013, p.5, ListMinSep procedure).
725+
# 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
726+
components_graph = graph.copy()
727+
components_graph.remove_nodes_from(close_separator_set)
723728

724-
:param graph: An undirected graph.
725-
:param treatment_node: The node corresponding to the treatment variable we wish to separate from the output.
726-
:param outcome_node: The node corresponding to the outcome variable we wish to separate from the input.
727-
:param treatment_node_set: Set of treatment nodes.
728-
:param outcome_node_set: Set of outcome nodes.
729-
:return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected graph.
730-
"""
731-
# 1. Compute the close separator of the treatment set
732-
close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set)
729+
# 3. Find the component containing the treatment node
730+
treatment_component = set()
731+
for component in nx.connected_components(components_graph):
732+
if treatment_node in component:
733+
treatment_component = component
734+
break
733735

734-
# 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
735-
components_graph = graph.copy()
736-
components_graph.remove_nodes_from(close_separator_set)
736+
# 4. Stop without yielding if the treatment node's connected component overlaps the outcome node set
737+
if treatment_component.intersection(outcome_node_set):
738+
return
737739

738-
# 3. Find the component containing the treatment node
739-
treatment_component = set()
740-
for component in nx.connected_components(components_graph):
741-
if treatment_node in component:
742-
treatment_component = component
743-
break
740+
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
741+
treatment_node_set = treatment_component
744742

745-
# 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
746-
if treatment_component.intersection(outcome_node_set):
747-
raise ValueError(
748-
f"Connected component {treatment_component} containing the treatment node is not disjoint with "
749-
f"the outcome node set {outcome_node_set}"
750-
)
743+
# 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
744+
neighbour_nodes = {
745+
neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set
746+
}
751747

752-
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
753-
treatment_node_set = treatment_component
754-
755-
# 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
756-
neighbour_nodes = {
757-
neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set
758-
}
759-
760-
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
761-
remaining = neighbour_nodes - outcome_node_set
762-
if remaining:
763-
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
764-
chosen = sample(sorted(remaining), 1)
765-
# 7.2. Add this node to the treatment node set and recurse (left branch)
766-
yield from list_all_min_sep_opt(
767-
graph,
768-
treatment_node,
769-
outcome_node,
770-
treatment_node_set.union(chosen),
771-
outcome_node_set,
772-
)
773-
# 7.3. Add this node to the outcome node set and recurse (right branch)
774-
yield from list_all_min_sep_opt(
775-
graph,
776-
treatment_node,
777-
outcome_node,
778-
treatment_node_set,
779-
outcome_node_set.union(chosen),
780-
)
781-
else:
782-
# Step 8: All neighbours are in outcome set — we found a separator
783-
yield neighbour_nodes
748+
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
749+
remaining = neighbour_nodes - outcome_node_set
750+
if remaining:
751+
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
752+
chosen = sample(sorted(remaining), 1)
753+
# 7.2. Add this node to the treatment node set and recurse (left branch)
754+
yield from self.list_all_min_sep_opt(
755+
graph,
756+
treatment_node,
757+
outcome_node,
758+
treatment_node_set.union(chosen),
759+
outcome_node_set,
760+
)
761+
# 7.3. Add this node to the outcome node set and recurse (right branch)
762+
yield from self.list_all_min_sep_opt(
763+
graph,
764+
treatment_node,
765+
outcome_node,
766+
treatment_node_set,
767+
outcome_node_set.union(chosen),
768+
)
769+
else:
770+
# 8. All neighbours are in the outcome node set, so the neighbours form a separator
771+
yield neighbour_nodes

0 commit comments

Comments
 (0)