Skip to content

Commit 75d067e

Browse files
committed
Add a direction parameter to the find_all method
Earlier the direction was hard coded, which made it unclear that the method was only finding outgoing nodes. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 6f43828 commit 75d067e

File tree

7 files changed

+56
-13
lines changed

7 files changed

+56
-13
lines changed

src/graph/formulas/generators/battery.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ where
2929
let inverter_ids = if let Some(battery_ids) = battery_ids {
3030
Self::find_inverter_ids(graph, &battery_ids)?
3131
} else {
32-
graph.find_all(graph.root_id, |node| node.is_battery_inverter(), false)?
32+
graph.find_all(
33+
graph.root_id,
34+
|node| node.is_battery_inverter(),
35+
petgraph::Direction::Outgoing,
36+
false,
37+
)?
3338
};
3439
Ok(Self {
3540
graph,

src/graph/formulas/generators/chp.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ where
2929
let chp_ids = if let Some(chp_ids) = chp_ids {
3030
chp_ids
3131
} else {
32-
graph.find_all(graph.root_id, |node| node.is_chp(), false)?
32+
graph.find_all(
33+
graph.root_id,
34+
|node| node.is_chp(),
35+
petgraph::Direction::Outgoing,
36+
false,
37+
)?
3338
};
3439
Ok(Self { graph, chp_ids })
3540
}

src/graph/formulas/generators/consumer.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ where
2424
{
2525
pub fn try_new(graph: &'a ComponentGraph<N, E>) -> Result<Self, Error> {
2626
Ok(Self {
27-
unvisited_meters: graph.find_all(graph.root_id, |node| node.is_meter(), true)?,
27+
unvisited_meters: graph.find_all(
28+
graph.root_id,
29+
|node| node.is_meter(),
30+
petgraph::Direction::Outgoing,
31+
true,
32+
)?,
2833
graph,
2934
})
3035
}

src/graph/formulas/generators/ev_charger.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ where
2929
let ev_charger_ids = if let Some(ev_charger_ids) = ev_charger_ids {
3030
ev_charger_ids
3131
} else {
32-
graph.find_all(graph.root_id, |node| node.is_ev_charger(), false)?
32+
graph.find_all(
33+
graph.root_id,
34+
|node| node.is_ev_charger(),
35+
petgraph::Direction::Outgoing,
36+
false,
37+
)?
3338
};
3439
Ok(Self {
3540
graph,

src/graph/formulas/generators/producer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ where
4141
|| node.is_pv_inverter()
4242
|| node.is_chp()
4343
},
44+
petgraph::Direction::Outgoing,
4445
false,
4546
)? {
4647
let comp_expr = Self::min_zero(self.graph.fallback_expr([component_id], false)?);

src/graph/formulas/generators/pv.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ where
2929
let pv_inverter_ids = if let Some(pv_inverter_ids) = pv_inverter_ids {
3030
pv_inverter_ids
3131
} else {
32-
graph.find_all(graph.root_id, |node| node.is_pv_inverter(), false)?
32+
graph.find_all(
33+
graph.root_id,
34+
|node| node.is_pv_inverter(),
35+
petgraph::Direction::Outgoing,
36+
false,
37+
)?
3338
};
3439
Ok(Self {
3540
graph,

src/graph/retrieval.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,15 @@ where
108108
}
109109

110110
/// Returns a set of all components that match the given predicate, starting
111-
/// from the component with the given `component_id`.
111+
/// from the component with the given `component_id`, in the given direction.
112112
///
113113
/// If `follow_after_match` is `true`, the search continues deeper beyond
114114
/// the matching components.
115115
pub(crate) fn find_all(
116116
&self,
117117
from: u64,
118118
mut pred: impl FnMut(&N) -> bool,
119+
direction: petgraph::Direction,
119120
follow_after_match: bool,
120121
) -> Result<BTreeSet<u64>, Error> {
121122
let index = self.node_indices.get(&from).ok_or_else(|| {
@@ -133,9 +134,7 @@ where
133134
}
134135
}
135136

136-
let neighbors = self
137-
.graph
138-
.neighbors_directed(index, petgraph::Direction::Outgoing);
137+
let neighbors = self.graph.neighbors_directed(index, direction);
139138
stack.extend(neighbors);
140139
}
141140

@@ -353,37 +352,55 @@ mod tests {
353352
let (components, connections) = nodes_and_edges();
354353
let graph = ComponentGraph::try_new(components.clone(), connections.clone())?;
355354

356-
let found = graph.find_all(graph.root_id, |x| x.is_meter(), false)?;
355+
let found = graph.find_all(
356+
graph.root_id,
357+
|x| x.is_meter(),
358+
petgraph::Direction::Outgoing,
359+
false,
360+
)?;
357361
assert_eq!(found, [2].iter().cloned().collect());
358362

359-
let found = graph.find_all(graph.root_id, |x| x.is_meter(), true)?;
363+
let found = graph.find_all(
364+
graph.root_id,
365+
|x| x.is_meter(),
366+
petgraph::Direction::Outgoing,
367+
true,
368+
)?;
360369
assert_eq!(found, [2, 3, 6].iter().cloned().collect());
361370

362371
let found = graph.find_all(
363372
graph.root_id,
364373
|x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
374+
petgraph::Direction::Outgoing,
365375
true,
366376
)?;
367377
assert_eq!(found, [2, 4, 5, 7, 8].iter().cloned().collect());
368378

369379
let found = graph.find_all(
370380
6,
371381
|x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
382+
petgraph::Direction::Outgoing,
372383
true,
373384
)?;
374385
assert_eq!(found, [7, 8].iter().cloned().collect());
375386

376387
let found = graph.find_all(
377388
graph.root_id,
378389
|x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
390+
petgraph::Direction::Outgoing,
379391
false,
380392
)?;
381393
assert_eq!(found, [2].iter().cloned().collect());
382394

383-
let found = graph.find_all(graph.root_id, |_| true, false)?;
395+
let found = graph.find_all(
396+
graph.root_id,
397+
|_| true,
398+
petgraph::Direction::Outgoing,
399+
false,
400+
)?;
384401
assert_eq!(found, [1].iter().cloned().collect());
385402

386-
let found = graph.find_all(3, |_| true, true)?;
403+
let found = graph.find_all(3, |_| true, petgraph::Direction::Outgoing, true)?;
387404
assert_eq!(found, [3, 4, 5].iter().cloned().collect());
388405

389406
Ok(())

0 commit comments

Comments
 (0)