Skip to content

Commit c9f06dd

Browse files
authored
Merge branch 'cocoindex-io:main' into feat-anthropic-dataflow
2 parents 67bbd48 + c7f6fbc commit c9f06dd

File tree

20 files changed

+301
-145
lines changed

20 files changed

+301
-145
lines changed

docs/docs/core/initialization.mdx

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,29 @@ This takes care of the following effects:
4949

5050
The following environment variables are supported:
5151

52-
* `COCOINDEX_DATABASE_URL`: The URL of the Postgres database to use as the internal storage, e.g. `postgres://cocoindex:cocoindex@localhost/cocoindex`
52+
* `COCOINDEX_DATABASE_URL` (required): The URI of the Postgres database to use as the internal storage, e.g. `postgres://cocoindex:cocoindex@localhost/cocoindex`
53+
* `COCOINDEX_DATABASE_USER` (optional): The username for the Postgres database. If not provided, username will come from `COCOINDEX_DATABASE_URL`.
54+
* `COCOINDEX_DATABASE_PASSWORD` (optional): The password for the Postgres database. If not provided, password will come from `COCOINDEX_DATABASE_URL`.
5355

5456
## Explicit Initialization
5557

5658
Alternatively, for flexibility, you can also explicitly initialize the library by the `init()` function:
5759

60+
### Settings
61+
62+
It takes a `Settings` object as argument, which is a dataclass that contains the following fields:
63+
64+
* `database` (type: `DatabaseConnectionSpec`, required): The connection to the Postgres database.
65+
66+
#### DatabaseConnectionSpec
67+
68+
`DatabaseConnectionSpec` has the following fields:
69+
* `uri` (type: `str`, required): The URI of the Postgres database to use as the internal storage, e.g. `postgres://cocoindex:cocoindex@localhost/cocoindex`.
70+
* `user` (type: `str`, optional): The username for the Postgres database. If not provided, username will come from `uri`.
71+
* `password` (type: `str`, optional): The password for the Postgres database. If not provided, password will come from `uri`.
72+
73+
### Example
74+
5875
<Tabs>
5976
<TabItem value="python" label="Python" default>
6077

@@ -63,7 +80,11 @@ import cocoindex
6380

6481
def main():
6582
...
66-
cocoindex.init(cocoindex.Settings(database_url="postgres://cocoindex:cocoindex@localhost/cocoindex"))
83+
cocoindex.init(
84+
cocoindex.Settings(
85+
database=cocoindex.DatabaseConnectionSpec(
86+
uri="postgres://cocoindex:cocoindex@localhost/cocoindex"
87+
)))
6788
...
6889

6990
...

docs/docs/ops/storages.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ It should be a unique table, meaning that no other export target should export t
3636

3737
The spec takes the following fields:
3838

39-
* `database_url` (type: `str`, optional): The URL of the Postgres database to use as the internal storage, e.g. `postgres://cocoindex:cocoindex@localhost/cocoindex`. If unspecified, will use the same database as the [internal storage](/docs/core/basics#internal-storage).
39+
* `database` (type: [auth reference](../core/flow_def#auth-registry) to `DatabaseConnectionSpec`, optional): The connection to the Postgres database.
40+
See [DatabaseConnectionSpec](../core/initialization#databaseconnectionspec) for its specific fields.
41+
If not provided, will use the same database as the [internal storage](/docs/core/basics#internal-storage).
4042

4143
* `table_name` (type: `str`, optional): The name of the table to store to. If unspecified, will generate a new automatically. We recommend specifying a name explicitly if you want to directly query the table. It can be omitted if you want to use CocoIndex's query handlers to query the table.
4244

python/cocoindex/lib.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Library level functions and states.
33
"""
4-
import asyncio
54
import os
65
import sys
76
import functools
@@ -12,6 +11,7 @@
1211

1312
from . import _engine
1413
from . import flow, query, cli
14+
from .convert import dump_engine_object
1515

1616

1717
def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
@@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
2222
else:
2323
target[name] = value
2424

25+
@dataclass
26+
class DatabaseConnectionSpec:
27+
uri: str
28+
user: str | None = None
29+
password: str | None = None
30+
2531
@dataclass
2632
class Settings:
2733
"""Settings for the cocoindex library."""
28-
database_url: str
34+
database: DatabaseConnectionSpec
2935

3036
@classmethod
3137
def from_env(cls) -> Self:
3238
"""Load settings from environment variables."""
3339

34-
kwargs: dict[str, str] = dict()
35-
_load_field(kwargs, "database_url", "COCOINDEX_DATABASE_URL", required=True)
36-
37-
return cls(**kwargs)
40+
db_kwargs: dict[str, str] = dict()
41+
_load_field(db_kwargs, "uri", "COCOINDEX_DATABASE_URL", required=True)
42+
_load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
43+
_load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
44+
database = DatabaseConnectionSpec(**db_kwargs)
45+
return cls(database=database)
3846

3947

4048
def init(settings: Settings):
4149
"""Initialize the cocoindex library."""
42-
_engine.init(settings.__dict__)
50+
_engine.init(dump_engine_object(settings))
4351

4452
@dataclass
4553
class ServerSettings:

python/cocoindex/storages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class Postgres(op.StorageSpec):
1010
"""Storage powered by Postgres and pgvector."""
1111

12-
database_url: str | None = None
12+
database: AuthEntryReference | None = None
1313
table_name: str | None = None
1414

1515
@dataclass

src/base/spec.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,69 @@ pub struct SimpleSemanticsQueryHandlerSpec {
296296
pub default_similarity_metric: VectorSimilarityMetric,
297297
}
298298

299-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
300-
pub struct AuthEntryReference {
299+
pub struct AuthEntryReference<T> {
301300
pub key: String,
301+
_phantom: std::marker::PhantomData<T>,
302+
}
303+
304+
impl<T> std::fmt::Debug for AuthEntryReference<T> {
305+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306+
write!(f, "AuthEntryReference({})", self.key)
307+
}
308+
}
309+
310+
impl<T> std::fmt::Display for AuthEntryReference<T> {
311+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312+
write!(f, "AuthEntryReference({})", self.key)
313+
}
314+
}
315+
316+
impl<T> Clone for AuthEntryReference<T> {
317+
fn clone(&self) -> Self {
318+
Self {
319+
key: self.key.clone(),
320+
_phantom: std::marker::PhantomData,
321+
}
322+
}
323+
}
324+
325+
#[derive(Serialize, Deserialize)]
326+
struct UntypedAuthEntryReference<T> {
327+
key: T,
328+
}
329+
330+
impl<T> Serialize for AuthEntryReference<T> {
331+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
332+
where
333+
S: serde::Serializer,
334+
{
335+
UntypedAuthEntryReference { key: &self.key }.serialize(serializer)
336+
}
337+
}
338+
339+
impl<'de, T> Deserialize<'de> for AuthEntryReference<T> {
340+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
341+
where
342+
D: serde::Deserializer<'de>,
343+
{
344+
let untyped_ref = UntypedAuthEntryReference::<String>::deserialize(deserializer)?;
345+
Ok(AuthEntryReference {
346+
key: untyped_ref.key,
347+
_phantom: std::marker::PhantomData,
348+
})
349+
}
350+
}
351+
352+
impl<T> PartialEq for AuthEntryReference<T> {
353+
fn eq(&self, other: &Self) -> bool {
354+
self.key == other.key
355+
}
356+
}
357+
358+
impl<T> Eq for AuthEntryReference<T> {}
359+
360+
impl<T> std::hash::Hash for AuthEntryReference<T> {
361+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
362+
self.key.hash(state);
363+
}
302364
}

src/builder/analyzed_flow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl AnalyzedFlow {
3030
registry,
3131
)?;
3232
let setup_status_check =
33-
setup::check_flow_setup_status(Some(&desired_state), existing_flow_ss)?;
33+
setup::check_flow_setup_status(Some(&desired_state), existing_flow_ss).await?;
3434
let execution_plan = if setup_status_check.is_up_to_date() {
3535
Some(
3636
async move {

src/execution/db_tracking_setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ impl ResourceSetupStatusCheck for TrackingTableSetupStatusCheck {
157157
}
158158

159159
async fn apply_change(&self) -> Result<()> {
160-
let pool = &get_lib_context()?.pool;
160+
let pool = &get_lib_context()?.builtin_db_pool;
161161
if let Some(desired) = &self.desired_state {
162162
for lagacy_name in self.legacy_table_names.iter() {
163163
let query = format!(

src/lib_context.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::settings;
66
use crate::setup;
77
use crate::{builder::AnalyzedFlow, execution::query::SimpleSemanticsQueryHandler};
88
use axum::http::StatusCode;
9+
use sqlx::postgres::PgConnectOptions;
910
use sqlx::PgPool;
1011
use std::collections::BTreeMap;
1112
use tokio::runtime::Runtime;
@@ -61,8 +62,40 @@ impl FlowContext {
6162
static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
6263
static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> = LazyLock::new(|| Arc::new(AuthRegistry::new()));
6364

65+
#[derive(Default)]
66+
pub struct DbPools {
67+
pub pools: Mutex<HashMap<(String, Option<String>), Arc<tokio::sync::OnceCell<PgPool>>>>,
68+
}
69+
70+
impl DbPools {
71+
pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
72+
let db_pool_cell = {
73+
let key = (conn_spec.uri.clone(), conn_spec.user.clone());
74+
let mut db_pools = self.pools.lock().unwrap();
75+
db_pools.entry(key).or_default().clone()
76+
};
77+
let pool = db_pool_cell
78+
.get_or_try_init(|| async move {
79+
let mut pg_options: PgConnectOptions = conn_spec.uri.parse()?;
80+
if let Some(user) = &conn_spec.user {
81+
pg_options = pg_options.username(user);
82+
}
83+
if let Some(password) = &conn_spec.password {
84+
pg_options = pg_options.password(password);
85+
}
86+
let pool = PgPool::connect_with(pg_options)
87+
.await
88+
.context("Failed to connect to database")?;
89+
anyhow::Ok(pool)
90+
})
91+
.await?;
92+
Ok(pool.clone())
93+
}
94+
}
95+
6496
pub struct LibContext {
65-
pub pool: PgPool,
97+
pub db_pools: DbPools,
98+
pub builtin_db_pool: PgPool,
6699
pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
67100
pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
68101
}
@@ -100,13 +133,15 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
100133
pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
101134
});
102135

136+
let db_pools = DbPools::default();
103137
let (pool, all_setup_states) = get_runtime().block_on(async {
104-
let pool = PgPool::connect(&settings.database_url).await?;
138+
let pool = db_pools.get_pool(&settings.database).await?;
105139
let existing_ss = setup::get_existing_setup_state(&pool).await?;
106140
anyhow::Ok((pool, existing_ss))
107141
})?;
108142
Ok(LibContext {
109-
pool,
143+
db_pools,
144+
builtin_db_pool: pool,
110145
all_setup_states: RwLock::new(all_setup_states),
111146
flows: Mutex::new(BTreeMap::new()),
112147
})

src/ops/factory_bases.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
304304

305305
/// Will not be called if it's setup by user.
306306
/// It returns an error if the target only supports setup by user.
307-
fn check_setup_status(
307+
async fn check_setup_status(
308308
&self,
309309
key: Self::Key,
310310
desired_state: Option<Self::SetupState>,
@@ -392,7 +392,7 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
392392
Ok((data_coll_output, decl_output))
393393
}
394394

395-
fn check_setup_status(
395+
async fn check_setup_status(
396396
&self,
397397
key: &serde_json::Value,
398398
desired_state: Option<serde_json::Value>,
@@ -410,7 +410,8 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
410410
desired_state,
411411
existing_states,
412412
auth_registry,
413-
)?;
413+
)
414+
.await?;
414415
Ok(Box::new(status_check))
415416
}
416417

@@ -419,6 +420,11 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
419420
StorageFactoryBase::describe_resource(self, &key)
420421
}
421422

423+
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value> {
424+
let key: T::Key = serde_json::from_value(key.clone())?;
425+
Ok(serde_json::to_value(key)?)
426+
}
427+
422428
fn check_state_compatibility(
423429
&self,
424430
desired_state: &serde_json::Value,

src/ops/interface.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,18 @@ pub trait ExportTargetFactory: Send + Sync {
190190

191191
/// Will not be called if it's setup by user.
192192
/// It returns an error if the target only supports setup by user.
193-
fn check_setup_status(
193+
async fn check_setup_status(
194194
&self,
195195
key: &serde_json::Value,
196196
desired_state: Option<serde_json::Value>,
197197
existing_states: setup::CombinedState<serde_json::Value>,
198198
auth_registry: &Arc<AuthRegistry>,
199199
) -> Result<Box<dyn setup::ResourceSetupStatusCheck>>;
200200

201+
/// Normalize the key. e.g. the JSON format may change (after code change, e.g. new optional field or field ordering), even if the underlying value is not changed.
202+
/// This should always return the canonical serialized form.
203+
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value>;
204+
201205
fn check_state_compatibility(
202206
&self,
203207
desired_state: &serde_json::Value,

0 commit comments

Comments
 (0)