diff --git a/crates/bevy_ecs/src/relationship/relationship_query.rs b/crates/bevy_ecs/src/relationship/relationship_query.rs index a7acea7de0732..02ec3f920ea52 100644 --- a/crates/bevy_ecs/src/relationship/relationship_query.rs +++ b/crates/bevy_ecs/src/relationship/relationship_query.rs @@ -1,3 +1,4 @@ +use super::SourceIter; use crate::{ entity::Entity, query::{QueryData, QueryFilter}, @@ -5,10 +6,9 @@ use crate::{ system::Query, }; use alloc::collections::VecDeque; +use core::marker::PhantomData; use smallvec::SmallVec; -use super::SourceIter; - impl<'w, 's, D: QueryData, F: QueryFilter> Query<'w, 's, D, F> { /// If the given `entity` contains the `R` [`Relationship`] component, returns the /// target entity of that relationship. @@ -92,7 +92,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Query<'w, 's, D, F> { } /// Iterates all descendant entities as defined by the given `entity`'s [`RelationshipTarget`] and their recursive - /// [`RelationshipTarget`]. + /// [`RelationshipTarget`] in breadth-first order. /// /// # Warning /// @@ -101,11 +101,11 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Query<'w, 's, D, F> { pub fn iter_descendants( &'w self, entity: Entity, - ) -> DescendantIter<'w, 's, D, F, S> + ) -> DescendantIter> where D::ReadOnly: QueryData = &'w S>, { - DescendantIter::new(self, entity) + DescendantIter(BreadthFirst::new(self, entity)) } /// Iterates all descendant entities as defined by the given `entity`'s [`RelationshipTarget`] and their recursive @@ -118,12 +118,12 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Query<'w, 's, D, F> { pub fn iter_descendants_depth_first( &'w self, entity: Entity, - ) -> DescendantDepthFirstIter<'w, 's, D, F, S> + ) -> DescendantIter> where D::ReadOnly: QueryData = &'w S>, SourceIter<'w, S>: DoubleEndedIterator, { - DescendantDepthFirstIter::new(self, entity) + DescendantIter(DepthFirst::new(self, entity)) } /// Iterates all ancestors of the given `entity` as defined by the `R` [`Relationship`]. @@ -143,10 +143,10 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Query<'w, 's, D, F> { } } -/// An [`Iterator`] of [`Entity`]s over the descendants of an [`Entity`]. +/// An iteration strategy of [`Entity`]s over the descendants of an [`Entity`]. /// /// Traverses the hierarchy breadth-first. -pub struct DescendantIter<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> +pub struct BreadthFirst<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> where D::ReadOnly: QueryData = &'w S>, { @@ -154,13 +154,12 @@ where vecdeque: VecDeque, } -impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> DescendantIter<'w, 's, D, F, S> +impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> BreadthFirst<'w, 's, D, F, S> where D::ReadOnly: QueryData = &'w S>, { - /// Returns a new [`DescendantIter`]. - pub fn new(children_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self { - DescendantIter { + fn new(children_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self { + Self { children_query, vecdeque: children_query .get(entity) @@ -171,28 +170,28 @@ where } } -impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> Iterator - for DescendantIter<'w, 's, D, F, S> +impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> DescendantsIterator + for BreadthFirst<'w, 's, D, F, S> where D::ReadOnly: QueryData = &'w S>, { - type Item = Entity; + fn next_node(&mut self) -> Option { + self.vecdeque.pop_front() + } - fn next(&mut self) -> Option { - let entity = self.vecdeque.pop_front()?; + fn set_children(&mut self, node: Entity) { + let Ok(children) = self.children_query.get(node) else { + return; + }; - if let Ok(children) = self.children_query.get(entity) { - self.vecdeque.extend(children.iter()); - } - - Some(entity) + self.vecdeque.extend(children.iter()); } } -/// An [`Iterator`] of [`Entity`]s over the descendants of an [`Entity`]. +/// An iteration strategy of [`Entity`]s over the descendants of an [`Entity`]. /// /// Traverses the hierarchy depth-first. -pub struct DescendantDepthFirstIter<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> +pub struct DepthFirst<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> where D::ReadOnly: QueryData = &'w S>, { @@ -200,15 +199,13 @@ where stack: SmallVec<[Entity; 8]>, } -impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> - DescendantDepthFirstIter<'w, 's, D, F, S> +impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> DepthFirst<'w, 's, D, F, S> where D::ReadOnly: QueryData = &'w S>, SourceIter<'w, S>: DoubleEndedIterator, { - /// Returns a new [`DescendantDepthFirstIter`]. - pub fn new(children_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self { - DescendantDepthFirstIter { + fn new(children_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self { + Self { children_query, stack: children_query .get(entity) @@ -217,25 +214,163 @@ where } } -impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> Iterator - for DescendantDepthFirstIter<'w, 's, D, F, S> +impl<'w, 's, D: QueryData, F: QueryFilter, S: RelationshipTarget> DescendantsIterator + for DepthFirst<'w, 's, D, F, S> where D::ReadOnly: QueryData = &'w S>, SourceIter<'w, S>: DoubleEndedIterator, +{ + fn next_node(&mut self) -> Option { + self.stack.pop() + } + + fn set_children(&mut self, node: Entity) { + let Ok(children) = self.children_query.get(node) else { + return; + }; + + self.stack.extend(children.iter().rev()); + } +} + +/// An [`Iterator`] of [`Entity`]s over the descendants of an [`Entity`]. +/// +/// Concrete traversal strategy depends on the `Traversal` type. +pub struct DescendantIter(Traversal); + +impl DescendantIter { + /// Creates an iterator which uses a closure to determine if recursive [`RelationshipTarget`]s + /// should be yielded. + /// + /// Once the provided closure returns `false` for an [`Entity`] it and its recursive + /// [`RelationshipTarget`]s will not be yielded, effectively skipping the sub hierarchy where + /// that [`Entity`] is the root. + pub fn filter_hierarchies(self, predicate: F) -> FilterHierarchies + where + F: FnMut(&Entity) -> bool, + { + FilterHierarchies { + iter: self, + predicate, + } + } + + /// Creates an iterator which uses a closure to both filter and map over recursive + /// [`RelationshipTarget`]s. + /// + /// Once the provided closure returns `None` for an [`Entity`] the mapped values for + /// it and its recursive [`RelationshipTarget`]s will not be yielded, effectively skipping the + /// sub hierarchy where that [`Entity`] is the root. + pub fn filter_map_hierarchies(self, map: F) -> FilterMapHierarchies + where + F: FnMut(Entity) -> Option, + { + FilterMapHierarchies { + iter: self, + map, + _p: PhantomData, + } + } +} + +impl Iterator for DescendantIter +where + Traversal: DescendantsIterator, +{ + type Item = Entity; + + fn next(&mut self) -> Option { + let next_root = self.0.next_node()?; + self.0.set_children(next_root); + + Some(next_root) + } +} + +impl DescendantsIterator for DescendantIter +where + Traversal: DescendantsIterator, +{ + fn next_node(&mut self) -> Option { + self.0.next_node() + } + + fn set_children(&mut self, node: Entity) { + self.0.set_children(node); + } +} + +/// An [`Iterator`] of [`Entity`]s over the descendants of an [`Entity`]. +/// +/// Allows conditional skipping of sub hierarchies. +pub struct FilterHierarchies { + iter: T, + predicate: F, +} + +impl Iterator for FilterHierarchies +where + T: DescendantsIterator, + F: FnMut(&Entity) -> bool, { type Item = Entity; fn next(&mut self) -> Option { - let entity = self.stack.pop()?; + let mut node; - if let Ok(children) = self.children_query.get(entity) { - self.stack.extend(children.iter().rev()); + loop { + node = self.iter.next_node()?; + if (self.predicate)(&node) { + break; + } } + self.iter.set_children(node); - Some(entity) + Some(node) } } +/// An [`Iterator`] of [`Entity`]s over the descendants of an [`Entity`]. +/// +/// Allows conditional skipping of sub hierarchies. +pub struct FilterMapHierarchies { + iter: T, + map: F, + _p: PhantomData, +} + +impl Iterator for FilterMapHierarchies +where + T: DescendantsIterator, + F: FnMut(Entity) -> Option, +{ + type Item = R; + + fn next(&mut self) -> Option { + let mut node; + let mut value; + + loop { + node = self.iter.next_node()?; + value = (self.map)(node); + if value.is_some() { + break; + } + } + self.iter.set_children(node); + + value + } +} + +/// A trait to implement a concrete descendant traversal strategy +/// +/// Used to streamline breadth-first and depth-first iteration +trait DescendantsIterator { + fn next_node(&mut self) -> Option; + fn set_children(&mut self, node: Entity); +} + /// An [`Iterator`] of [`Entity`]s over the ancestors of an [`Entity`]. pub struct AncestorIter<'w, 's, D: QueryData, F: QueryFilter, R: Relationship> where @@ -270,3 +405,193 @@ where self.next } } + +#[cfg(test)] +mod test_iter_descendants { + use crate::{ + prelude::*, + system::{RunSystemError, RunSystemOnce}, + }; + use alloc::{vec, vec::Vec}; + + mod iter_descendants_breadth_first { + use super::*; + + #[test] + fn iter_all() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world.spawn_empty().id(); + let a = world.spawn(ChildOf(root)).id(); + let aa = world.spawn(ChildOf(a)).id(); + let ab = world.spawn(ChildOf(a)).id(); + let b = world.spawn(ChildOf(root)).id(); + let ba = world.spawn(ChildOf(b)).id(); + let bb = world.spawn(ChildOf(b)).id(); + + let descendants = world.run_system_once(move |c: Query<&Children>| { + c.iter_descendants(root).collect::>() + })?; + + assert_eq!(descendants, vec![a, b, aa, ab, ba, bb]); + Ok(()) + } + } + + mod iter_descendants_depth_first { + use super::*; + + #[test] + fn iter_all() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world.spawn_empty().id(); + let a = world.spawn(ChildOf(root)).id(); + let aa = world.spawn(ChildOf(a)).id(); + let ab = world.spawn(ChildOf(a)).id(); + let b = world.spawn(ChildOf(root)).id(); + let ba = world.spawn(ChildOf(b)).id(); + let bb = world.spawn(ChildOf(b)).id(); + + let descendants = world.run_system_once(move |c: Query<&Children>| { + c.iter_descendants_depth_first(root).collect::>() + })?; + + assert_eq!(descendants, vec![a, aa, ab, b, ba, bb]); + Ok(()) + } + } + + mod filter_hierarchies { + use super::*; + + #[test] + fn iter_all() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world.spawn_empty().id(); + let children = vec![ + world.spawn(ChildOf(root)).id(), + world.spawn(ChildOf(root)).id(), + world.spawn(ChildOf(root)).id(), + ]; + + let descendants = world.run_system_once(move |c: Query<&Children>| { + c.iter_descendants(root) + .filter_hierarchies(|_| true) + .collect::>() + })?; + + assert_eq!(descendants, children); + Ok(()) + } + + #[test] + fn skip_entity_when_flat() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world.spawn_empty().id(); + let a = world.spawn(ChildOf(root)).id(); + let skip = world.spawn(ChildOf(root)).id(); + let b = world.spawn(ChildOf(root)).id(); + + let descendants = world.run_system_once(move |c: Query<&Children>| { + c.iter_descendants(root) + .filter_hierarchies(|e| e != &skip) + .collect::>() + })?; + + assert_eq!(descendants, vec![a, b]); + Ok(()) + } + + #[test] + fn skip_sub_hierarchy() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world.spawn_empty().id(); + let a = world.spawn(ChildOf(root)).id(); + let skip = world.spawn((ChildOf(root), children![(), ()])).id(); + let b = world.spawn(ChildOf(root)).id(); + + let descendants = world.run_system_once(move |c: Query<&Children>| { + c.iter_descendants(root) + .filter_hierarchies(|e| e != &skip) + .collect::>() + })?; + + assert_eq!(descendants, vec![a, b]); + Ok(()) + } + } + + mod map_hierarchies { + use super::*; + + #[test] + fn iter_all() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world + .spawn(children![Name::from("a"), Name::from("b"), Name::from("c")]) + .id(); + + let names = world.run_system_once(move |c: Query<&Children>, n: Query<&Name>| { + c.iter_descendants(root) + .filter_map_hierarchies(|e| n.get(e).ok().cloned()) + .collect::>() + })?; + + assert_eq!( + names, + vec![Name::from("a"), Name::from("b"), Name::from("c")] + ); + Ok(()) + } + + #[test] + fn skip_entity_when_flat() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world + .spawn(children![ + Name::from("a"), + Name::from("skip"), + Name::from("b"), + ]) + .id(); + + let names = world.run_system_once(move |c: Query<&Children>, n: Query<&Name>| { + c.iter_descendants(root) + .filter_map_hierarchies(|e| match n.get(e) { + Ok(name) if name.as_str() != "skip" => Some(name.clone()), + _ => None, + }) + .collect::>() + })?; + + assert_eq!(names, vec![Name::from("a"), Name::from("b")]); + Ok(()) + } + + #[test] + fn skip_sub_hierarchy() -> Result<(), RunSystemError> { + let mut world = World::new(); + let root = world + .spawn(children![ + Name::from("a"), + ( + Name::from("skip"), + children![Name::from("skip child a"), Name::from("skip child b")] + ), + Name::from("b"), + ]) + .id(); + + let names = world.run_system_once(move |c: Query<&Children>, n: Query<&Name>| { + c.iter_descendants(root) + .filter_map_hierarchies(|e| match n.get(e) { + Ok(name) if name.as_str() != "skip" => Some(name.clone()), + _ => None, + }) + .collect::>() + })?; + + assert_eq!(names, vec![Name::from("a"), Name::from("b")]); + Ok(()) + } + } +}