From 8e80101adbf031ddbea46549217ae688eae8ad7c Mon Sep 17 00:00:00 2001 From: LJ Date: Fri, 25 Apr 2025 11:00:11 -0700 Subject: [PATCH] refactor(auth-ref): make `AuthEntryReference` typed --- src/base/spec.rs | 60 ++++++++++++++++++++++++++++++++++++-- src/ops/storages/neo4j.rs | 8 ++--- src/setup/auth_registry.rs | 2 +- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/src/base/spec.rs b/src/base/spec.rs index 5b23d37f8..3b63ab4b7 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -296,7 +296,63 @@ pub struct SimpleSemanticsQueryHandlerSpec { pub default_similarity_metric: VectorSimilarityMetric, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub struct AuthEntryReference { +pub struct AuthEntryReference { pub key: String, + _phantom: std::marker::PhantomData, +} + +impl std::fmt::Debug for AuthEntryReference { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "AuthEntryReference({})", self.key) + } +} + +impl std::fmt::Display for AuthEntryReference { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "AuthEntryReference({})", self.key) + } +} + +impl Clone for AuthEntryReference { + fn clone(&self) -> Self { + Self { + key: self.key.clone(), + _phantom: std::marker::PhantomData, + } + } +} + +impl Serialize for AuthEntryReference { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.key.serialize(serializer) + } +} + +impl<'de, T> Deserialize<'de> for AuthEntryReference { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self { + key: String::deserialize(deserializer)?, + _phantom: std::marker::PhantomData, + }) + } +} + +impl PartialEq for AuthEntryReference { + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } +} + +impl Eq for AuthEntryReference {} + +impl std::hash::Hash for AuthEntryReference { + fn hash(&self, state: &mut H) { + self.key.hash(state); + } } diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index cd5b09de1..e028574aa 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -24,13 +24,13 @@ pub struct ConnectionSpec { #[derive(Debug, Deserialize)] pub struct Spec { - connection: spec::AuthEntryReference, + connection: spec::AuthEntryReference, mapping: GraphElementMapping, } #[derive(Debug, Deserialize)] pub struct Declaration { - connection: spec::AuthEntryReference, + connection: spec::AuthEntryReference, #[serde(flatten)] decl: GraphDeclarations, } @@ -92,7 +92,7 @@ impl std::fmt::Display for ElementType { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct GraphElement { - connection: AuthEntryReference, + connection: AuthEntryReference, typ: ElementType, } @@ -156,7 +156,7 @@ struct AnalyzedNodeLabelInfo { } pub struct ExportContext { - connection_ref: AuthEntryReference, + connection_ref: AuthEntryReference, graph: Arc, create_order: u8, diff --git a/src/setup/auth_registry.rs b/src/setup/auth_registry.rs index 9254bd760..fbe3769d4 100644 --- a/src/setup/auth_registry.rs +++ b/src/setup/auth_registry.rs @@ -32,7 +32,7 @@ impl AuthRegistry { Ok(()) } - pub fn get(&self, entry_ref: &spec::AuthEntryReference) -> Result { + pub fn get(&self, entry_ref: &spec::AuthEntryReference) -> Result { let entries = self.entries.read().unwrap(); match entries.get(&entry_ref.key) { Some(value) => Ok(serde_json::from_value(value.clone())?),