diff --git a/src/linked_hash_map.rs b/src/linked_hash_map.rs index b27c98b..f608fec 100644 --- a/src/linked_hash_map.rs +++ b/src/linked_hash_map.rs @@ -531,6 +531,93 @@ where } } } + + /// Returns iterator that was skipped to a specific key entry, or None. + /// It does not implement `ExactSizeIterator` because + /// it is unclear where exactly the iterator is. + /// + /// It is useful when iterating over a subset of + /// all items in order, e.g. for starting a queue iteration at a specific key + /// + /// # Examples + /// + /// ```rs + /// let mut map = LinkedHashMap::new(); + /// + /// map.insert("a", 10); + /// map.insert("b", 20); + /// map.insert("c", 30); + /// + /// assert_eq!(map.iter_at_key(&"e").is_none(), true); + /// + /// let mut iter = map.iter_at_key(&"b").unwrap(); + /// assert_eq!((&"b", &20), iter.next().unwrap()); + /// assert_eq!((&"c", &30), iter.next().unwrap()); + /// assert_eq!(None, iter.next()); + /// assert_eq!(None, iter.next()); + /// ``` + /// + #[inline] + pub fn iter_at_key(&self, k: &K) -> Option> { + let tail = unsafe { self.values?.as_ref().links.value.prev }; + + let hash = hash_key(&self.hash_builder, k); + let node = unsafe { + *self + .map + .raw_entry() + .from_hash(hash, move |key| k.eq((*key).as_ref().key_ref()))? + .0 + }; + Some(IterAtKey { + tail: tail.as_ptr(), + cur: node.as_ptr(), + marker: PhantomData, + }) + } + + /// Returns a mutable iterator that was skipped to a specific key entry, or None. + /// It does not implement `ExactSizeIterator` because + /// it is unclear where exactly the iterator is. + /// + /// It is useful when iterating over a subset of + /// all items in order, e.g. for starting a queue iteration at a specific key + /// + /// # Examples + /// + /// ```rs + /// let mut map = LinkedHashMap::new(); + /// map.insert("a", 10); + /// map.insert("c", 30); + /// map.insert("b", 20); + /// map.insert("d", 40); + /// + /// assert_eq!(map.iter_at_key_mut(&"e").is_none(), true); + /// + /// let mut iter = map.iter_at_key_mut(&"c").unwrap(); + /// let entry = iter.next().unwrap(); + /// assert_eq!("c", *entry.0); + /// *entry.1 = 17; + /// + /// assert_eq!(format!("{:?}", iter), "[(\"b\", 20), (\"d\", 40)]"); + /// assert_eq!(17, map[&"c"]); + /// ``` + /// + #[inline] + pub fn iter_at_key_mut(&mut self, k: &K) -> Option> { + let tail = unsafe { self.values?.as_ref().links.value.prev }; + match self.raw_entry_mut().from_key(k) { + RawEntryMut::Occupied(entry) => { + let cur = entry.entry.key(); + Some(IterAtKeyMut { + tail: Some(tail), + cur: Some(*cur), + marker: PhantomData, + }) + } + RawEntryMut::Vacant(_) => None, + } + } } impl LinkedHashMap @@ -1384,6 +1471,18 @@ pub struct Drain<'a, K, V> { marker: PhantomData<(K, V, &'a LinkedHashMap)>, } +pub struct IterAtKey<'a, K, V> { + tail: *const Node, + cur: *const Node, + marker: PhantomData<(&'a K, &'a V)>, +} + +pub struct IterAtKeyMut<'a, K, V> { + tail: Option>>, + cur: Option>>, + marker: PhantomData<(&'a K, &'a mut V)>, +} + impl IterMut<'_, K, V> { #[inline] pub(crate) fn iter(&self) -> Iter<'_, K, V> { @@ -1420,6 +1519,17 @@ impl Drain<'_, K, V> { } } +impl IterAtKeyMut<'_, K, V> { + #[inline] + pub(crate) fn iter(&self) -> IterAtKey<'_, K, V> { + IterAtKey { + tail: self.tail.as_ptr(), + cur: self.cur.as_ptr(), + marker: PhantomData, + } + } +} + unsafe impl<'a, K, V> Send for Iter<'a, K, V> where K: Send, @@ -1448,6 +1558,20 @@ where { } +unsafe impl<'a, K, V> Send for IterAtKey<'a, K, V> +where + K: Send, + V: Send, +{ +} + +unsafe impl<'a, K, V> Send for IterAtKeyMut<'a, K, V> +where + K: Send, + V: Send, +{ +} + unsafe impl<'a, K, V> Sync for Iter<'a, K, V> where K: Sync, @@ -1476,6 +1600,20 @@ where { } +unsafe impl<'a, K, V> Sync for IterAtKey<'a, K, V> +where + K: Sync, + V: Sync, +{ +} + +unsafe impl<'a, K, V> Sync for IterAtKeyMut<'a, K, V> +where + K: Sync, + V: Sync, +{ +} + impl<'a, K, V> Clone for Iter<'a, K, V> { #[inline] fn clone(&self) -> Self { @@ -1483,6 +1621,17 @@ impl<'a, K, V> Clone for Iter<'a, K, V> { } } +impl<'a, K, V> Clone for IterAtKey<'a, K, V> { + #[inline] + fn clone(&self) -> Self { + IterAtKey { + tail: self.tail, + cur: self.cur, + marker: PhantomData, + } + } +} + impl fmt::Debug for Iter<'_, K, V> { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -1523,6 +1672,28 @@ where } } +impl fmt::Debug for IterAtKey<'_, K, V> +where + K: fmt::Debug, + V: fmt::Debug, +{ + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.clone()).finish() + } +} + +impl fmt::Debug for IterAtKeyMut<'_, K, V> +where + K: fmt::Debug, + V: fmt::Debug, +{ + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.iter()).finish() + } +} + impl<'a, K, V> Iterator for Iter<'a, K, V> { type Item = (&'a K, &'a V); @@ -1617,6 +1788,47 @@ impl<'a, K, V> Iterator for Drain<'a, K, V> { } } +impl<'a, K, V> Iterator for IterAtKey<'a, K, V> { + type Item = (&'a K, &'a V); + + #[inline] + fn next(&mut self) -> Option<(&'a K, &'a V)> { + if self.cur.is_null() { + return None; + } + unsafe { + let last_iter = self.cur == self.tail; + let (key, value) = (*self.cur).entry_ref(); + self.cur = (*self.cur).links.value.next.as_ptr(); + if last_iter { + self.cur = std::ptr::null(); + } + Some((key, value)) + } + } +} + +impl<'a, K, V> Iterator for IterAtKeyMut<'a, K, V> { + type Item = (&'a K, &'a mut V); + + #[inline] + fn next(&mut self) -> Option<(&'a K, &'a mut V)> { + if self.cur.is_none() { + None + } else { + unsafe { + let last_iter = self.cur == self.tail; + let (key, value) = (*self.cur.as_ptr()).entry_mut(); + self.cur = Some((*self.cur.as_ptr()).links.value.next); + if last_iter { + self.cur = None; + } + Some((key, value)) + } + } + } +} + impl<'a, K, V> DoubleEndedIterator for Iter<'a, K, V> { #[inline] fn next_back(&mut self) -> Option<(&'a K, &'a V)> { @@ -1683,6 +1895,42 @@ impl<'a, K, V> DoubleEndedIterator for Drain<'a, K, V> { } } +impl<'a, K, V> DoubleEndedIterator for IterAtKey<'a, K, V> { + #[inline] + fn next_back(&mut self) -> Option<(&'a K, &'a V)> { + if self.cur.is_null() { + None + } else { + unsafe { + if self.cur == self.tail { + self.cur = std::ptr::null(); + } + let (key, value) = (*self.tail).entry_ref(); + self.tail = (*self.tail).links.value.prev.as_ptr(); + Some((key, value)) + } + } + } +} + +impl<'a, K, V> DoubleEndedIterator for IterAtKeyMut<'a, K, V> { + #[inline] + fn next_back(&mut self) -> Option<(&'a K, &'a mut V)> { + if self.cur.is_none() { + None + } else { + unsafe { + if self.cur == self.tail { + self.cur = None; + } + let (key, value) = (*self.tail.as_ptr()).entry_mut(); + self.tail = Some((*self.tail.as_ptr()).links.value.prev); + Some((key, value)) + } + } + } +} + impl<'a, K, V> ExactSizeIterator for Iter<'a, K, V> {} impl<'a, K, V> ExactSizeIterator for IterMut<'a, K, V> {} diff --git a/src/linked_hash_set.rs b/src/linked_hash_set.rs index 5a89875..111c987 100644 --- a/src/linked_hash_set.rs +++ b/src/linked_hash_set.rs @@ -305,6 +305,38 @@ where { self.map.retain_with_order(|k, _| f(k)); } + + /// Returns iterator that was skipped to a specific key entry, or None. + /// It does not implement `ExactSizeIterator` because + /// it is unclear where exactly the iterator is. + /// + /// It is useful when iterating over a subset of + /// all items in order, e.g. for starting a queue iteration at a specific key + /// + /// # Examples + /// + /// ```rs + /// let mut map = LinkedHashSet::new(); + /// + /// map.insert("a"); + /// map.insert("b"); + /// map.insert("c"); + /// + /// assert_eq!(map.iter_at_key(&"e").is_none(), true); + /// + /// // regular iter + /// let mut iter = map.iter_at_key(&"b").unwrap(); + /// assert_eq!(&"b", iter.next().unwrap()); + /// assert_eq!(&"c", iter.next().unwrap()); + /// assert_eq!(None, iter.next()); + /// assert_eq!(None, iter.next()); + /// ``` + /// + #[inline] + pub fn iter_at_key(&self, k: &T) -> Option> { + let iter = self.map.iter_at_key(k)?; + Some(IterAtKey { iter }) + } } impl Clone for LinkedHashSet { @@ -467,6 +499,10 @@ pub struct Drain<'a, K: 'a> { iter: linked_hash_map::Drain<'a, K, ()>, } +pub struct IterAtKey<'a, K> { + iter: linked_hash_map::IterAtKey<'a, K, ()>, +} + pub struct Intersection<'a, T, S> { iter: Iter<'a, T>, other: &'a LinkedHashSet, @@ -601,6 +637,38 @@ impl<'a, T, S> Clone for Intersection<'a, T, S> { } } +impl<'a, K> Clone for IterAtKey<'a, K> { + #[inline] + fn clone(&self) -> IterAtKey<'a, K> { + IterAtKey { + iter: self.iter.clone(), + } + } +} + +impl<'a, K> Iterator for IterAtKey<'a, K> { + type Item = &'a K; + + #[inline] + fn next(&mut self) -> Option<&'a K> { + Some(self.iter.next()?.0) + } +} + +impl<'a, K> DoubleEndedIterator for IterAtKey<'a, K> { + #[inline] + fn next_back(&mut self) -> Option<&'a K> { + Some(self.iter.next_back()?.0) + } +} + +impl<'a, K: fmt::Debug> fmt::Debug for IterAtKey<'a, K> { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.clone()).finish() + } +} + impl<'a, T, S> Iterator for Intersection<'a, T, S> where T: Eq + Hash, diff --git a/tests/linked_hash_map.rs b/tests/linked_hash_map.rs index e046292..bf62844 100644 --- a/tests/linked_hash_map.rs +++ b/tests/linked_hash_map.rs @@ -446,6 +446,78 @@ fn test_drain() { assert_eq!(c.get(), 3); } +#[test] +fn test_iter_at_key() { + let mut map = LinkedHashMap::new(); + + map.insert("a", 10); + map.insert("b", 20); + map.insert("c", 30); + + assert_eq!(map.iter_at_key(&"e").is_none(), true); + + // regular iter + let mut iter = map.iter_at_key(&"b").unwrap(); + assert_eq!((&"b", &20), iter.next().unwrap()); + assert_eq!((&"c", &30), iter.next().unwrap()); + assert_eq!(None, iter.next()); + assert_eq!(None, iter.next()); + + let mut iter = map.iter_at_key(&"b").unwrap(); + assert_eq!((&"b", &20), iter.next().unwrap()); + let mut iclone = iter.clone(); + assert_eq!((&"c", &30), iter.next().unwrap()); + assert_eq!((&"c", &30), iclone.next().unwrap()); + + // reversed iter + let mut rev_iter = map.iter_at_key(&"b").unwrap().rev(); + assert_eq!((&"c", &30), rev_iter.next().unwrap()); + assert_eq!((&"b", &20), rev_iter.next().unwrap()); + assert_eq!(None, rev_iter.next()); + assert_eq!(None, rev_iter.next()); + + // mixed + let mut mixed_iter = map.iter_at_key(&"b").unwrap(); + assert_eq!((&"b", &20), mixed_iter.next().unwrap()); + assert_eq!((&"c", &30), mixed_iter.next_back().unwrap()); + assert_eq!(None, mixed_iter.next()); + assert_eq!(None, mixed_iter.next_back()); +} + +#[test] +fn test_iter_at_key_mut() { + let mut map = LinkedHashMap::new(); + map.insert("a", 10); + map.insert("c", 30); + map.insert("b", 20); + map.insert("d", 40); + + { + assert_eq!(map.iter_at_key_mut(&"e").is_none(), true); + + let mut iter = map.iter_at_key_mut(&"c").unwrap(); + let entry = iter.next().unwrap(); + assert_eq!("c", *entry.0); + *entry.1 = 17; + + assert_eq!(format!("{:?}", iter), "[(\"b\", 20), (\"d\", 40)]"); + + // reverse iterator + let mut iter = iter.rev(); + let entry = iter.next().unwrap(); + assert_eq!("d", *entry.0); + *entry.1 = 23; + + let entry = iter.next().unwrap(); + assert_eq!("b", *entry.0); + assert_eq!(None, iter.next()); + assert_eq!(None, iter.next()); + } + + assert_eq!(17, map[&"c"]); + assert_eq!(23, map[&"d"]); +} + #[test] fn test_send_sync() { fn is_send_sync() {} @@ -458,6 +530,8 @@ fn test_send_sync() { is_send_sync::>(); is_send_sync::>(); is_send_sync::>(); + is_send_sync::>(); + is_send_sync::>(); is_send_sync::>(); is_send_sync::>(); } diff --git a/tests/linked_hash_set.rs b/tests/linked_hash_set.rs index 7a9e33f..d73080b 100644 --- a/tests/linked_hash_set.rs +++ b/tests/linked_hash_set.rs @@ -146,6 +146,44 @@ fn test_iterate() { assert_eq!(observed, 0xFFFF_FFFF); } +#[test] +fn test_iter_at_key() { + let mut map = LinkedHashSet::new(); + + map.insert("a"); + map.insert("b"); + map.insert("c"); + + assert_eq!(map.iter_at_key(&"e").is_none(), true); + + // regular iter + let mut iter = map.iter_at_key(&"b").unwrap(); + assert_eq!(&"b", iter.next().unwrap()); + assert_eq!(&"c", iter.next().unwrap()); + assert_eq!(None, iter.next()); + assert_eq!(None, iter.next()); + + let mut iter = map.iter_at_key(&"b").unwrap(); + assert_eq!(&"b", iter.next().unwrap()); + let mut iclone = iter.clone(); + assert_eq!(&"c", iter.next().unwrap()); + assert_eq!(&"c", iclone.next().unwrap()); + + // reversed iter + let mut rev_iter = map.iter_at_key(&"b").unwrap().rev(); + assert_eq!(&"c", rev_iter.next().unwrap()); + assert_eq!(&"b", rev_iter.next().unwrap()); + assert_eq!(None, rev_iter.next()); + assert_eq!(None, rev_iter.next()); + + // mixed + let mut mixed_iter = map.iter_at_key(&"b").unwrap(); + assert_eq!(&"b", mixed_iter.next().unwrap()); + assert_eq!(&"c", mixed_iter.next_back().unwrap()); + assert_eq!(None, mixed_iter.next()); + assert_eq!(None, mixed_iter.next_back()); +} + #[test] fn test_intersection() { let mut a = LinkedHashSet::new();