Skip to content

Commit 465ecb0

Browse files
committed
refactor: introduce PersistenceContext to encapsulate database pool and setup states
1 parent 49aa13e commit 465ecb0

File tree

3 files changed

+60
-14
lines changed

3 files changed

+60
-14
lines changed

src/builder/flow_builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ impl FlowBuilder {
259259
pub fn new(name: &str) -> PyResult<Self> {
260260
let lib_context = get_lib_context().into_py_result()?;
261261
let existing_flow_ss = lib_context
262-
.all_setup_states
262+
.require_all_setup_states()
263+
.into_py_result()?
263264
.read()
264265
.unwrap()
265266
.flows

src/lib_context.rs

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,15 @@ impl DbPools {
7777
}
7878
}
7979

80+
pub struct PersistenceContext {
81+
pub builtin_db_pool: PgPool,
82+
pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
83+
}
84+
8085
pub struct LibContext {
8186
pub db_pools: DbPools,
82-
pub builtin_db_pool: Option<PgPool>,
87+
pub persistence_ctx: Option<PersistenceContext>,
8388
pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
84-
pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
8589
}
8690

8791
impl LibContext {
@@ -100,8 +104,16 @@ impl LibContext {
100104
}
101105

102106
pub fn require_builtin_db_pool(&self) -> Result<&PgPool> {
103-
self.builtin_db_pool
107+
self.persistence_ctx
108+
.as_ref()
109+
.map(|ctx| &ctx.builtin_db_pool)
110+
.ok_or_else(|| anyhow!("Database is required for this operation. Please set COCOINDEX_DATABASE_URL environment variable and call cocoindex.init() with database settings."))
111+
}
112+
113+
pub fn require_all_setup_states(&self) -> Result<&RwLock<setup::AllSetupState<setup::ExistingMode>>> {
114+
self.persistence_ctx
104115
.as_ref()
116+
.map(|ctx| &ctx.all_setup_states)
105117
.ok_or_else(|| anyhow!("Database is required for this operation. Please set COCOINDEX_DATABASE_URL environment variable and call cocoindex.init() with database settings."))
106118
}
107119
}
@@ -123,22 +135,24 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
123135
});
124136

125137
let db_pools = DbPools::default();
126-
let (pool, all_setup_states) = if let Some(database_spec) = &settings.database {
138+
let persistence_ctx = if let Some(database_spec) = &settings.database {
127139
let (pool, all_setup_states) = get_runtime().block_on(async {
128140
let pool = db_pools.get_pool(database_spec).await?;
129141
let existing_ss = setup::get_existing_setup_state(&pool).await?;
130-
anyhow::Ok((Some(pool), existing_ss))
142+
anyhow::Ok((pool, existing_ss))
131143
})?;
132-
(pool, all_setup_states)
144+
Some(PersistenceContext {
145+
builtin_db_pool: pool,
146+
all_setup_states: RwLock::new(all_setup_states),
147+
})
133148
} else {
134-
// No database configured - create empty setup states
135-
(None, setup::AllSetupState::default())
149+
// No database configured
150+
None
136151
};
137152

138153
Ok(LibContext {
139154
db_pools,
140-
builtin_db_pool: pool,
141-
all_setup_states: RwLock::new(all_setup_states),
155+
persistence_ctx,
142156
flows: Mutex::new(BTreeMap::new()),
143157
})
144158
}
@@ -185,4 +199,35 @@ mod tests {
185199
assert!(settings.database.is_none());
186200
assert_eq!(settings.app_namespace, "test");
187201
}
202+
203+
#[test]
204+
fn test_lib_context_without_database() {
205+
let settings = settings::Settings {
206+
database: None,
207+
app_namespace: "test".to_string(),
208+
};
209+
210+
let lib_context = create_lib_context(settings).unwrap();
211+
assert!(lib_context.persistence_ctx.is_none());
212+
assert!(lib_context.require_builtin_db_pool().is_err());
213+
assert!(lib_context.require_all_setup_states().is_err());
214+
}
215+
216+
#[test]
217+
fn test_persistence_context_type_safety() {
218+
// This test ensures that PersistenceContext groups related fields together
219+
let settings = settings::Settings {
220+
database: Some(settings::DatabaseConnectionSpec {
221+
url: "postgresql://test".to_string(),
222+
user: None,
223+
password: None,
224+
}),
225+
app_namespace: "test".to_string(),
226+
};
227+
228+
// This would fail at runtime due to invalid connection, but we're testing the structure
229+
let result = create_lib_context(settings);
230+
// We expect this to fail due to invalid connection, but the structure should be correct
231+
assert!(result.is_err());
232+
}
188233
}

src/py/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ impl SetupStatus {
394394
fn sync_setup(py: Python<'_>) -> PyResult<SetupStatus> {
395395
let lib_context = get_lib_context().into_py_result()?;
396396
let flows = lib_context.flows.lock().unwrap();
397-
let all_setup_states = lib_context.all_setup_states.read().unwrap();
397+
let all_setup_states = lib_context.require_all_setup_states().into_py_result()?.read().unwrap();
398398
py.allow_threads(|| {
399399
get_runtime()
400400
.block_on(async {
@@ -408,7 +408,7 @@ fn sync_setup(py: Python<'_>) -> PyResult<SetupStatus> {
408408
#[pyfunction]
409409
fn drop_setup(py: Python<'_>, flow_names: Vec<String>) -> PyResult<SetupStatus> {
410410
let lib_context = get_lib_context().into_py_result()?;
411-
let all_setup_states = lib_context.all_setup_states.read().unwrap();
411+
let all_setup_states = lib_context.require_all_setup_states().into_py_result()?.read().unwrap();
412412
py.allow_threads(|| {
413413
get_runtime()
414414
.block_on(async {
@@ -422,7 +422,7 @@ fn drop_setup(py: Python<'_>, flow_names: Vec<String>) -> PyResult<SetupStatus>
422422
#[pyfunction]
423423
fn flow_names_with_setup() -> PyResult<Vec<String>> {
424424
let lib_context = get_lib_context().into_py_result()?;
425-
let all_setup_states = lib_context.all_setup_states.read().unwrap();
425+
let all_setup_states = lib_context.require_all_setup_states().into_py_result()?.read().unwrap();
426426
let flow_names = all_setup_states.flows.keys().cloned().collect();
427427
Ok(flow_names)
428428
}

0 commit comments

Comments
 (0)