@@ -81,16 +81,27 @@ def __init__(
8181 """
8282
8383 self .original_graph_copy = deepcopy (G )
84- self .G = stdigraph .stDiGraph (G , additional_starts = additional_starts , additional_ends = additional_ends )
84+ self .sparsity_lambda = sparsity_lambda
85+
86+ if nx .is_directed_acyclic_graph (G ):
87+ self .is_acyclic = True
88+ self .G = stdigraph .stDiGraph (G , additional_starts = additional_starts , additional_ends = additional_ends )
89+ self .edges_to_ignore = set (edges_to_ignore ).union (self .G .source_sink_edges )
90+ else :
91+ self .G = G
92+ self .is_acyclic = False
93+ self .edges_to_ignore = set (edges_to_ignore )
94+ if self .sparsity_lambda != 0 :
95+ utils .logger .error (f"{ __name__ } : You cannot set sparsity_lambda != 0 for a graph with cycles." )
96+ raise ValueError (f"You cannot set sparsity_lambda != 0 for a graph with cycles." )
97+
8598 self .flow_attr = flow_attr
8699 if weight_type not in [int , float ]:
87100 utils .logger .error (f"{ __name__ } : weight_type must be either int or float, not { weight_type } " )
88101 raise ValueError (f"weight_type must be either int or float, not { weight_type } " )
89102 self .weight_type = weight_type
90103 self .solver_options = solver_options
91-
92- self .sparsity_lambda = sparsity_lambda
93- self .edges_to_ignore = set (edges_to_ignore ).union (self .G .source_sink_edges )
104+
94105 self .edge_error_scaling = edge_error_scaling
95106 # Checking that every entry in self.edge_error_scaling is between 0 and 1
96107 for key , value in self .edge_error_scaling .items ():
@@ -156,7 +167,7 @@ def __encode_flow(self):
156167
157168 # Adding flow conservation constraints
158169 for node in self .G .nodes ():
159- if node in [ self .G .source , self .G .sink ] :
170+ if self .G .in_degree ( node ) == 0 or self .G .out_degree ( node ) == 0 :
160171 continue
161172 # Flow conservation constraint
162173 self .solver .add_constraint (
@@ -212,9 +223,9 @@ def __encode_min_sum_errors_objective(self):
212223 self .edge_error_vars [(u , v )] * self .edge_error_scaling .get ((u , v ), 1 )
213224 for (u , v ) in self .G .edges ()
214225 if (u , v ) not in self .edges_to_ignore
215- ) + self .sparsity_lambda * self .solver .quicksum (
226+ ) + ( self .sparsity_lambda * self .solver .quicksum (
216227 self .edge_vars [(u , v )]
217- for (u , v ) in self .G .out_edges (self .G .source )
228+ for (u , v ) in self .G .out_edges (self .G .source )) if self . sparsity_lambda > 0 else 0
218229 ),
219230 sense = "minimize" ,
220231 )
@@ -286,9 +297,9 @@ def __encode_different_flow_values_and_objective(
286297 self .edge_error_vars [(u , v )] * self .edge_error_scaling .get ((u , v ), 1 )
287298 for (u , v ) in self .G .edges ()
288299 if (u , v ) not in self .edges_to_ignore
289- ) + self .sparsity_lambda * self .solver .quicksum (
300+ ) + ( self .sparsity_lambda * self .solver .quicksum (
290301 self .edge_vars [(u , v )]
291- for (u , v ) in self .G .out_edges (self .G .source )
302+ for (u , v ) in self .G .out_edges (self .G .source )) if self . sparsity_lambda > 0 else 0
292303 ) <= (1 + self .different_flow_values_epsilon ) * objective_value ,
293304 name = "epsilon_constraint" ,
294305 )
0 commit comments