|
| 1 | +use std::{ |
| 2 | + any::{Any, TypeId}, |
| 3 | + collections::HashMap, |
| 4 | + sync::Arc, |
| 5 | +}; |
| 6 | + |
| 7 | +use bitwarden_error::bitwarden_error; |
| 8 | +use thiserror::Error; |
| 9 | + |
| 10 | +use super::repository::RepositoryItemRegistration; |
| 11 | +use crate::repository::{Repository, RepositoryItem}; |
| 12 | + |
| 13 | +/// A registry that contains repositories for different types of items. |
| 14 | +/// These repositories can be either managed by the client or by the SDK itself. |
| 15 | +pub struct StateRegistry { |
| 16 | + client_managed: HashMap<TypeId, Box<dyn Any + Send + Sync>>, |
| 17 | +} |
| 18 | + |
| 19 | +impl std::fmt::Debug for StateRegistry { |
| 20 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 21 | + f.debug_struct("StateRegistry") |
| 22 | + .field("client_managed", &self.client_managed.keys()) |
| 23 | + .finish() |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +#[allow(missing_docs)] |
| 28 | +#[bitwarden_error(flat)] |
| 29 | +#[derive(Debug, Error)] |
| 30 | +pub enum StateRegistryError { |
| 31 | + #[error("Repository for type {0} is not registered as client-managed")] |
| 32 | + RepositoryNotClientManaged(&'static str), |
| 33 | + |
| 34 | + #[error("Not all client-managed repositories are registered")] |
| 35 | + NotAllRepositoriesRegistered, |
| 36 | +} |
| 37 | + |
| 38 | +impl StateRegistry { |
| 39 | + /// Creates a new empty `StateRegistry`. |
| 40 | + #[allow(clippy::new_without_default)] |
| 41 | + pub fn new() -> Self { |
| 42 | + StateRegistry { |
| 43 | + client_managed: HashMap::new(), |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + /// Registers a client-managed repository into the map, associating it with its type. |
| 48 | + pub fn register_client_managed<T: RepositoryItem>( |
| 49 | + &mut self, |
| 50 | + value: Arc<dyn Repository<T>>, |
| 51 | + ) -> Result<(), StateRegistryError> { |
| 52 | + let mut possible_registrations = RepositoryItemRegistration::iter(); |
| 53 | + match possible_registrations.find(|reg| reg.type_id() == TypeId::of::<T>()) { |
| 54 | + Some(reg) => { |
| 55 | + if !reg.rtype.is_client_managed() { |
| 56 | + return Err(StateRegistryError::RepositoryNotClientManaged(reg.name)); |
| 57 | + } |
| 58 | + } |
| 59 | + // This should never happen, as we have tests to ensure all repositories are registered. |
| 60 | + _ => { |
| 61 | + return Err(StateRegistryError::NotAllRepositoriesRegistered); |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + self.client_managed |
| 66 | + .insert(TypeId::of::<T>(), Box::new(value)); |
| 67 | + |
| 68 | + Ok(()) |
| 69 | + } |
| 70 | + |
| 71 | + /// Retrieves a client-managed repository from the map given its type. |
| 72 | + pub fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> { |
| 73 | + self.client_managed |
| 74 | + .get(&TypeId::of::<T>()) |
| 75 | + .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>()) |
| 76 | + .map(Arc::clone) |
| 77 | + } |
| 78 | + |
| 79 | + /// Validates that all repositories registered in the client-managed state registry. |
| 80 | + /// This should only be called after all the repositories have been registered by the clients. |
| 81 | + pub fn validate_repositories(&self) -> Result<(), StateRegistryError> { |
| 82 | + let possible_registrations = RepositoryItemRegistration::iter(); |
| 83 | + let mut missing_repository = false; |
| 84 | + |
| 85 | + for reg in possible_registrations { |
| 86 | + if reg.rtype.is_client_managed() && !self.client_managed.contains_key(®.type_id()) { |
| 87 | + log::error!( |
| 88 | + "Repository for type {} is not registered in the client-managed state registry", |
| 89 | + reg.name |
| 90 | + ); |
| 91 | + missing_repository = true; |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + if missing_repository { |
| 96 | + return Err(StateRegistryError::NotAllRepositoriesRegistered); |
| 97 | + } |
| 98 | + |
| 99 | + Ok(()) |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +#[cfg(test)] |
| 104 | +mod tests { |
| 105 | + use super::*; |
| 106 | + use crate::{ |
| 107 | + register_repository_item, |
| 108 | + repository::{RepositoryError, RepositoryItem}, |
| 109 | + }; |
| 110 | + |
| 111 | + macro_rules! impl_repository { |
| 112 | + ($name:ident, $ty:ty) => { |
| 113 | + #[async_trait::async_trait] |
| 114 | + impl Repository<$ty> for $name { |
| 115 | + async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> { |
| 116 | + Ok(Some(TestItem(self.0.clone()))) |
| 117 | + } |
| 118 | + async fn list(&self) -> Result<Vec<$ty>, RepositoryError> { |
| 119 | + unimplemented!() |
| 120 | + } |
| 121 | + async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> { |
| 122 | + unimplemented!() |
| 123 | + } |
| 124 | + async fn remove(&self, _key: String) -> Result<(), RepositoryError> { |
| 125 | + unimplemented!() |
| 126 | + } |
| 127 | + } |
| 128 | + }; |
| 129 | + } |
| 130 | + |
| 131 | + #[derive(PartialEq, Eq, Debug)] |
| 132 | + struct TestA(usize); |
| 133 | + #[derive(PartialEq, Eq, Debug)] |
| 134 | + struct TestB(String); |
| 135 | + #[derive(PartialEq, Eq, Debug)] |
| 136 | + struct TestC(Vec<u8>); |
| 137 | + #[derive(PartialEq, Eq, Debug)] |
| 138 | + struct TestItem<T>(T); |
| 139 | + |
| 140 | + register_repository_item!(TestItem<usize>, "TestItem<usize>", ClientManaged); |
| 141 | + register_repository_item!(TestItem<String>, "TestItem<String>", ClientManaged); |
| 142 | + register_repository_item!(TestItem<Vec<u8>>, "TestItem<Vec<u8>>", ClientManaged); |
| 143 | + |
| 144 | + impl_repository!(TestA, TestItem<usize>); |
| 145 | + impl_repository!(TestB, TestItem<String>); |
| 146 | + impl_repository!(TestC, TestItem<Vec<u8>>); |
| 147 | + |
| 148 | + #[tokio::test] |
| 149 | + async fn test_repository_map() { |
| 150 | + let a = Arc::new(TestA(145832)); |
| 151 | + let b = Arc::new(TestB("test".to_string())); |
| 152 | + let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9])); |
| 153 | + |
| 154 | + let mut map = StateRegistry::new(); |
| 155 | + |
| 156 | + async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> { |
| 157 | + map.get_client_managed::<T>() |
| 158 | + .unwrap() |
| 159 | + .get(String::new()) |
| 160 | + .await |
| 161 | + .unwrap() |
| 162 | + } |
| 163 | + |
| 164 | + assert!(map.get_client_managed::<TestItem<usize>>().is_none()); |
| 165 | + assert!(map.get_client_managed::<TestItem<String>>().is_none()); |
| 166 | + assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none()); |
| 167 | + |
| 168 | + map.register_client_managed(a.clone()).unwrap(); |
| 169 | + assert_eq!(get(&map).await, Some(TestItem(a.0))); |
| 170 | + assert!(map.get_client_managed::<TestItem<String>>().is_none()); |
| 171 | + assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none()); |
| 172 | + |
| 173 | + map.register_client_managed(b.clone()).unwrap(); |
| 174 | + assert_eq!(get(&map).await, Some(TestItem(a.0))); |
| 175 | + assert_eq!(get(&map).await, Some(TestItem(b.0.clone()))); |
| 176 | + assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none()); |
| 177 | + |
| 178 | + map.register_client_managed(c.clone()).unwrap(); |
| 179 | + assert_eq!(get(&map).await, Some(TestItem(a.0))); |
| 180 | + assert_eq!(get(&map).await, Some(TestItem(b.0.clone()))); |
| 181 | + assert_eq!(get(&map).await, Some(TestItem(c.0.clone()))); |
| 182 | + } |
| 183 | +} |
0 commit comments