From 71b2ee17b4ed352c3e995de1562a40fde39ef016 Mon Sep 17 00:00:00 2001 From: LJ Date: Wed, 9 Apr 2025 18:09:25 -0700 Subject: [PATCH] Add a `AuthRegistry` and put it in `AnalyzerContext`. --- src/base/spec.rs | 5 +++++ src/builder/analyzed_flow.rs | 15 +++++---------- src/builder/analyzer.rs | 10 +++++++--- src/builder/flow_builder.rs | 6 +++++- src/lib_context.rs | 2 ++ src/ops/interface.rs | 1 + src/prelude.rs | 7 ++++++- src/setup/auth_registry.rs | 36 ++++++++++++++++++++++++++++++++++++ src/setup/mod.rs | 2 ++ 9 files changed, 69 insertions(+), 15 deletions(-) create mode 100644 src/setup/auth_registry.rs diff --git a/src/base/spec.rs b/src/base/spec.rs index 1d8b344f4..2321610ff 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -282,3 +282,8 @@ pub struct SimpleSemanticsQueryHandlerSpec { pub query_transform_flow: TransientFlowSpec, pub default_similarity_metric: VectorSimilarityMetric, } + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct AuthEntryReference { + pub key: String, +} diff --git a/src/builder/analyzed_flow.rs b/src/builder/analyzed_flow.rs index 68e8ecbb2..1485b3a8d 100644 --- a/src/builder/analyzed_flow.rs +++ b/src/builder/analyzed_flow.rs @@ -1,18 +1,11 @@ -use std::sync::Arc; +use crate::prelude::*; use super::{analyzer, plan}; use crate::{ - api_error, - base::{schema, spec}, ops::registry::ExecutorFactoryRegistry, service::error::{shared_ok, SharedError, SharedResultExt}, setup::{self, ObjectSetupStatusCheck}, }; -use anyhow::Result; -use futures::{ - future::{BoxFuture, Shared}, - FutureExt, -}; pub struct AnalyzedFlow { pub flow_instance: spec::FlowInstanceSpec, @@ -28,8 +21,9 @@ impl AnalyzedFlow { flow_instance: crate::base::spec::FlowInstanceSpec, existing_flow_ss: Option<&setup::FlowSetupState>, registry: &ExecutorFactoryRegistry, + auth_registry: Arc, ) -> Result { - let ctx = analyzer::build_flow_instance_context(&flow_instance.name); + let ctx = analyzer::build_flow_instance_context(&flow_instance.name, auth_registry); let (data_schema, execution_plan_fut, desired_state) = analyzer::analyze_flow(&flow_instance, &ctx, existing_flow_ss, registry)?; let setup_status_check = @@ -79,8 +73,9 @@ impl AnalyzedTransientFlow { pub async fn from_transient_flow( transient_flow: spec::TransientFlowSpec, registry: &ExecutorFactoryRegistry, + auth_registry: Arc, ) -> Result { - let ctx = analyzer::build_flow_instance_context(&transient_flow.name); + let ctx = analyzer::build_flow_instance_context(&transient_flow.name, auth_registry); let (output_type, data_schema, execution_plan_fut) = analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?; Ok(Self { diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 973407ef8..0aabedf75 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -5,8 +5,8 @@ use std::{collections::HashMap, future::Future, sync::Arc}; use super::plan::*; use crate::execution::db_tracking_setup; use crate::setup::{ - self, DesiredMode, FlowSetupMetadata, FlowSetupState, ResourceIdentifier, SourceSetupState, - TargetSetupState, TargetSetupStateCommon, + self, AuthRegistry, DesiredMode, FlowSetupMetadata, FlowSetupState, ResourceIdentifier, + SourceSetupState, TargetSetupState, TargetSetupStateCommon, }; use crate::utils::fingerprint::Fingerprinter; use crate::{ @@ -1027,9 +1027,13 @@ impl AnalyzerContext<'_> { } } -pub fn build_flow_instance_context(flow_inst_name: &str) -> Arc { +pub fn build_flow_instance_context( + flow_inst_name: &str, + auth_registry: Arc, +) -> Arc { Arc::new(FlowInstanceContext { flow_instance_name: flow_inst_name.to_string(), + auth_registry, }) } diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs index 2735a95e0..a60b1f32d 100644 --- a/src/builder/flow_builder.rs +++ b/src/builder/flow_builder.rs @@ -347,9 +347,11 @@ impl FlowBuilder { .get(name) .cloned(); let root_data_scope = Arc::new(Mutex::new(DataScopeBuilder::new())); + let flow_inst_context = + build_flow_instance_context(name, lib_context.auth_registry.clone()); let result = Self { lib_context, - flow_inst_context: build_flow_instance_context(name), + flow_inst_context, existing_flow_ss, root_data_scope_ref: DataScopeRef(Arc::new(DataScopeRefInfo { @@ -648,6 +650,7 @@ impl FlowBuilder { spec, self.existing_flow_ss.as_ref(), &crate::ops::executor_factory_registry(), + self.lib_context.auth_registry.clone(), )) }) .into_py_result()?; @@ -688,6 +691,7 @@ impl FlowBuilder { get_runtime().block_on(super::AnalyzedTransientFlow::from_transient_flow( spec, &crate::ops::executor_factory_registry(), + self.lib_context.auth_registry.clone(), )) }) .into_py_result()?; diff --git a/src/lib_context.rs b/src/lib_context.rs index 9b71ece40..54f27aae8 100644 --- a/src/lib_context.rs +++ b/src/lib_context.rs @@ -63,6 +63,7 @@ static TOKIO_RUNTIME: LazyLock = LazyLock::new(|| Runtime::new().unwrap pub struct LibContext { pub pool: PgPool, pub flows: Mutex>>, + pub auth_registry: Arc, pub all_setup_states: RwLock>, } @@ -103,6 +104,7 @@ pub fn create_lib_context(settings: settings::Settings) -> Result { pool, all_setup_states: RwLock::new(all_setup_states), flows: Mutex::new(BTreeMap::new()), + auth_registry: Arc::new(AuthRegistry::new()), }) } diff --git a/src/ops/interface.rs b/src/ops/interface.rs index 90b9356e6..a487c894b 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -12,6 +12,7 @@ use serde::Serialize; pub struct FlowInstanceContext { pub flow_instance_name: String, + pub auth_registry: Arc, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] diff --git a/src/prelude.rs b/src/prelude.rs index afcfd5905..66fb99314 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,7 +3,11 @@ pub(crate) use anyhow::Result; pub(crate) use async_trait::async_trait; pub(crate) use chrono::{DateTime, Utc}; -pub(crate) use futures::{future::BoxFuture, prelude::*, stream::BoxStream}; +pub(crate) use futures::{ + future::{BoxFuture, Shared}, + prelude::*, + stream::BoxStream, +}; pub(crate) use futures::{FutureExt, StreamExt}; pub(crate) use indexmap::{IndexMap, IndexSet}; pub(crate) use itertools::Itertools; @@ -19,6 +23,7 @@ pub(crate) use crate::execution; pub(crate) use crate::lib_context::{get_lib_context, get_runtime, FlowContext, LibContext}; pub(crate) use crate::ops::interface; pub(crate) use crate::service::error::ApiError; +pub(crate) use crate::setup::AuthRegistry; pub(crate) use crate::{api_bail, api_error}; diff --git a/src/setup/auth_registry.rs b/src/setup/auth_registry.rs new file mode 100644 index 000000000..50aa3560f --- /dev/null +++ b/src/setup/auth_registry.rs @@ -0,0 +1,36 @@ +use std::collections::hash_map; + +use crate::prelude::*; + +pub struct AuthRegistry { + entries: RwLock>, +} + +impl AuthRegistry { + pub fn new() -> Self { + Self { + entries: RwLock::new(HashMap::new()), + } + } + + pub fn add(&self, key: String, value: serde_json::Value) -> Result<()> { + let mut entries = self.entries.write().unwrap(); + match entries.entry(key) { + hash_map::Entry::Occupied(entry) => { + api_bail!("Auth entry already exists: {}", entry.key()); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(value); + } + } + Ok(()) + } + + pub fn get(&self, entry_ref: &spec::AuthEntryReference) -> Result { + let entries = self.entries.read().unwrap(); + match entries.get(&entry_ref.key) { + Some(value) => Ok(serde_json::from_value(value.clone())?), + None => api_bail!("Auth entry not found: {}", entry_ref.key), + } + } +} diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 46a88c39f..7aeeab876 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -1,6 +1,8 @@ +mod auth_registry; mod db_metadata; mod driver; mod states; +pub use auth_registry::AuthRegistry; pub use driver::*; pub use states::*;