Skip to content

Commit 2dc16c8

Browse files
authored
feat(db): allow user/password to be configured separately for postgres (#391)
1 parent 39ce910 commit 2dc16c8

File tree

13 files changed

+172
-112
lines changed

13 files changed

+172
-112
lines changed

python/cocoindex/lib.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Library level functions and states.
33
"""
4-
import asyncio
54
import os
65
import sys
76
import functools
@@ -12,6 +11,7 @@
1211

1312
from . import _engine
1413
from . import flow, query, cli
14+
from .convert import dump_engine_object
1515

1616

1717
def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
@@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
2222
else:
2323
target[name] = value
2424

25+
@dataclass
26+
class DatabaseConnectionSpec:
27+
uri: str
28+
user: str | None = None
29+
password: str | None = None
30+
2531
@dataclass
2632
class Settings:
2733
"""Settings for the cocoindex library."""
28-
database_url: str
34+
database: DatabaseConnectionSpec
2935

3036
@classmethod
3137
def from_env(cls) -> Self:
3238
"""Load settings from environment variables."""
3339

34-
kwargs: dict[str, str] = dict()
35-
_load_field(kwargs, "database_url", "COCOINDEX_DATABASE_URL", required=True)
36-
37-
return cls(**kwargs)
40+
db_kwargs: dict[str, str] = dict()
41+
_load_field(db_kwargs, "uri", "COCOINDEX_DATABASE_URL", required=True)
42+
_load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
43+
_load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
44+
database = DatabaseConnectionSpec(**db_kwargs)
45+
return cls(database=database)
3846

3947

4048
def init(settings: Settings):
4149
"""Initialize the cocoindex library."""
42-
_engine.init(settings.__dict__)
50+
_engine.init(dump_engine_object(settings))
4351

4452
@dataclass
4553
class ServerSettings:

src/base/spec.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,17 @@ impl<T> Clone for AuthEntryReference<T> {
322322
}
323323
}
324324

325+
#[derive(Serialize, Deserialize)]
326+
struct UntypedAuthEntryReference<T> {
327+
key: T,
328+
}
329+
325330
impl<T> Serialize for AuthEntryReference<T> {
326331
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
327332
where
328333
S: serde::Serializer,
329334
{
330-
self.key.serialize(serializer)
335+
UntypedAuthEntryReference { key: &self.key }.serialize(serializer)
331336
}
332337
}
333338

@@ -336,8 +341,9 @@ impl<'de, T> Deserialize<'de> for AuthEntryReference<T> {
336341
where
337342
D: serde::Deserializer<'de>,
338343
{
339-
Ok(Self {
340-
key: String::deserialize(deserializer)?,
344+
let untyped_ref = UntypedAuthEntryReference::<String>::deserialize(deserializer)?;
345+
Ok(AuthEntryReference {
346+
key: untyped_ref.key,
341347
_phantom: std::marker::PhantomData,
342348
})
343349
}

src/execution/db_tracking_setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ impl ResourceSetupStatusCheck for TrackingTableSetupStatusCheck {
157157
}
158158

159159
async fn apply_change(&self) -> Result<()> {
160-
let pool = &get_lib_context()?.pool;
160+
let pool = &get_lib_context()?.builtin_db_pool;
161161
if let Some(desired) = &self.desired_state {
162162
for lagacy_name in self.legacy_table_names.iter() {
163163
let query = format!(

src/lib_context.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::settings;
66
use crate::setup;
77
use crate::{builder::AnalyzedFlow, execution::query::SimpleSemanticsQueryHandler};
88
use axum::http::StatusCode;
9+
use sqlx::postgres::PgConnectOptions;
910
use sqlx::PgPool;
1011
use std::collections::BTreeMap;
1112
use tokio::runtime::Runtime;
@@ -61,8 +62,40 @@ impl FlowContext {
6162
static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
6263
static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> = LazyLock::new(|| Arc::new(AuthRegistry::new()));
6364

65+
#[derive(Default)]
66+
pub struct DbPools {
67+
pub pools: Mutex<HashMap<(String, Option<String>), Arc<tokio::sync::OnceCell<PgPool>>>>,
68+
}
69+
70+
impl DbPools {
71+
pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
72+
let db_pool_cell = {
73+
let key = (conn_spec.uri.clone(), conn_spec.user.clone());
74+
let mut db_pools = self.pools.lock().unwrap();
75+
db_pools.entry(key).or_default().clone()
76+
};
77+
let pool = db_pool_cell
78+
.get_or_try_init(|| async move {
79+
let mut pg_options: PgConnectOptions = conn_spec.uri.parse()?;
80+
if let Some(user) = &conn_spec.user {
81+
pg_options = pg_options.username(user);
82+
}
83+
if let Some(password) = &conn_spec.password {
84+
pg_options = pg_options.password(password);
85+
}
86+
let pool = PgPool::connect_with(pg_options)
87+
.await
88+
.context("Failed to connect to database")?;
89+
anyhow::Ok(pool)
90+
})
91+
.await?;
92+
Ok(pool.clone())
93+
}
94+
}
95+
6496
pub struct LibContext {
65-
pub pool: PgPool,
97+
pub db_pools: DbPools,
98+
pub builtin_db_pool: PgPool,
6699
pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
67100
pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
68101
}
@@ -100,13 +133,15 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
100133
pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
101134
});
102135

136+
let db_pools = DbPools::default();
103137
let (pool, all_setup_states) = get_runtime().block_on(async {
104-
let pool = PgPool::connect(&settings.database_url).await?;
138+
let pool = db_pools.get_pool(&settings.database).await?;
105139
let existing_ss = setup::get_existing_setup_state(&pool).await?;
106140
anyhow::Ok((pool, existing_ss))
107141
})?;
108142
Ok(LibContext {
109-
pool,
143+
db_pools,
144+
builtin_db_pool: pool,
110145
all_setup_states: RwLock::new(all_setup_states),
111146
flows: Mutex::new(BTreeMap::new()),
112147
})

src/ops/factory_bases.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,11 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
420420
StorageFactoryBase::describe_resource(self, &key)
421421
}
422422

423+
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value> {
424+
let key: T::Key = serde_json::from_value(key.clone())?;
425+
Ok(serde_json::to_value(key)?)
426+
}
427+
423428
fn check_state_compatibility(
424429
&self,
425430
desired_state: &serde_json::Value,

src/ops/interface.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ pub trait ExportTargetFactory: Send + Sync {
198198
auth_registry: &Arc<AuthRegistry>,
199199
) -> Result<Box<dyn setup::ResourceSetupStatusCheck>>;
200200

201+
/// Normalize the key. e.g. the JSON format may change (after code change, e.g. new optional field or field ordering), even if the underlying value is not changed.
202+
/// This should always return the canonical serialized form.
203+
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value>;
204+
201205
fn check_state_compatibility(
202206
&self,
203207
desired_state: &serde_json::Value,

src/ops/registration.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
1313
functions::split_recursively::Factory.register(registry)?;
1414
functions::extract_by_llm::Factory.register(registry)?;
1515

16-
Arc::new(storages::postgres::Factory::default()).register(registry)?;
16+
storages::postgres::Factory::default().register(registry)?;
1717
Arc::new(storages::qdrant::Factory::default()).register(registry)?;
1818

1919
storages::neo4j::Factory::new().register(registry)?;

0 commit comments

Comments
 (0)