Skip to content

Commit 3e4080b

Browse files
committed
refactor(init): make lib initiazation logic async
1 parent 6ffb322 commit 3e4080b

File tree

7 files changed

+49
-42
lines changed

7 files changed

+49
-42
lines changed

src/builder/analyzer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ pub async fn analyze_flow(
995995
impl Future<Output = Result<ExecutionPlan>> + Send + use<>,
996996
)> {
997997
let analyzer_ctx = AnalyzerContext {
998-
lib_ctx: get_lib_context()?,
998+
lib_ctx: get_lib_context().await?,
999999
flow_ctx,
10001000
};
10011001
let root_data_scope = Arc::new(Mutex::new(DataScopeBuilder::new()));
@@ -1109,7 +1109,7 @@ pub async fn analyze_transient_flow<'a>(
11091109
)> {
11101110
let mut root_data_scope = DataScopeBuilder::new();
11111111
let analyzer_ctx = AnalyzerContext {
1112-
lib_ctx: get_lib_context()?,
1112+
lib_ctx: get_lib_context().await?,
11131113
flow_ctx,
11141114
};
11151115
let mut input_fields = vec![];

src/builder/flow_builder.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,12 @@ pub struct FlowBuilder {
247247
#[pymethods]
248248
impl FlowBuilder {
249249
#[new]
250-
pub fn new(name: &str) -> PyResult<Self> {
251-
let lib_context = get_lib_context().into_py_result()?;
250+
pub fn new(py: Python<'_>, name: &str) -> PyResult<Self> {
251+
let lib_context = py
252+
.allow_threads(|| -> anyhow::Result<Arc<LibContext>> {
253+
get_runtime().block_on(get_lib_context())
254+
})
255+
.into_py_result()?;
252256
let root_op_scope = OpScope::new(
253257
spec::ROOT_SCOPE_NAME.to_string(),
254258
None,

src/execution/db_tracking_setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ impl ResourceSetupChange for TrackingTableSetupChange {
252252

253253
impl TrackingTableSetupChange {
254254
pub async fn apply_change(&self) -> Result<()> {
255-
let lib_context = get_lib_context()?;
255+
let lib_context = get_lib_context().await?;
256256
let pool = lib_context.require_builtin_db_pool()?;
257257
if let Some(desired) = &self.desired_state {
258258
for lagacy_name in self.legacy_tracking_table_names.iter() {

src/lib_context.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ pub fn get_auth_registry() -> &'static Arc<AuthRegistry> {
267267
}
268268

269269
static LIB_INIT: OnceLock<()> = OnceLock::new();
270-
pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
270+
pub async fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
271271
LIB_INIT.get_or_init(|| {
272272
let _ = env_logger::try_init();
273273

@@ -278,11 +278,8 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
278278

279279
let db_pools = DbPools::default();
280280
let persistence_ctx = if let Some(database_spec) = &settings.database {
281-
let (pool, all_setup_states) = get_runtime().block_on(async {
282-
let pool = db_pools.get_pool(database_spec).await?;
283-
let existing_ss = setup::get_existing_setup_state(&pool).await?;
284-
anyhow::Ok((pool, existing_ss))
285-
})?;
281+
let pool = db_pools.get_pool(database_spec).await?;
282+
let all_setup_states = setup::get_existing_setup_state(&pool).await?;
286283
Some(PersistenceContext {
287284
builtin_db_pool: pool,
288285
setup_ctx: tokio::sync::RwLock::new(LibSetupContext {
@@ -308,24 +305,26 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
308305
})
309306
}
310307

311-
pub static LIB_CONTEXT: RwLock<Option<Arc<LibContext>>> = RwLock::new(None);
308+
static GET_SETTINGS_FN: Mutex<Option<Box<dyn Fn() -> settings::Settings + Send + Sync>>> =
309+
Mutex::new(None);
310+
static LIB_CONTEXT: Mutex<Option<Arc<LibContext>>> = Mutex::new(None);
312311

313-
pub(crate) fn init_lib_context(settings: settings::Settings) -> Result<()> {
314-
let mut lib_context_locked = LIB_CONTEXT.write().unwrap();
315-
*lib_context_locked = Some(Arc::new(create_lib_context(settings)?));
312+
pub(crate) async fn init_lib_context(settings: settings::Settings) -> Result<()> {
313+
let mut lib_context_locked = LIB_CONTEXT.lock().unwrap();
314+
*lib_context_locked = Some(Arc::new(create_lib_context(settings).await?));
316315
Ok(())
317316
}
318317

319-
pub(crate) fn get_lib_context() -> Result<Arc<LibContext>> {
320-
let lib_context_locked = LIB_CONTEXT.read().unwrap();
318+
pub(crate) async fn get_lib_context() -> Result<Arc<LibContext>> {
319+
let lib_context_locked = LIB_CONTEXT.lock().unwrap();
321320
lib_context_locked
322321
.as_ref()
323322
.cloned()
324323
.ok_or_else(|| anyhow!("CocoIndex library is not initialized or already stopped"))
325324
}
326325

327326
pub(crate) fn clear_lib_context() {
328-
let mut lib_context_locked = LIB_CONTEXT.write().unwrap();
327+
let mut lib_context_locked = LIB_CONTEXT.lock().unwrap();
329328
*lib_context_locked = None;
330329
}
331330

@@ -339,15 +338,17 @@ mod tests {
339338
assert!(db_pools.pools.lock().unwrap().is_empty());
340339
}
341340

342-
#[test]
343-
fn test_lib_context_without_database() {
344-
let lib_context = create_lib_context(settings::Settings::default()).unwrap();
341+
#[tokio::test]
342+
async fn test_lib_context_without_database() {
343+
let lib_context = create_lib_context(settings::Settings::default())
344+
.await
345+
.unwrap();
345346
assert!(lib_context.persistence_ctx.is_none());
346347
assert!(lib_context.require_builtin_db_pool().is_err());
347348
}
348349

349-
#[test]
350-
fn test_persistence_context_type_safety() {
350+
#[tokio::test]
351+
async fn test_persistence_context_type_safety() {
351352
// This test ensures that PersistenceContext groups related fields together
352353
let settings = settings::Settings {
353354
database: Some(settings::DatabaseConnectionSpec {
@@ -361,7 +362,7 @@ mod tests {
361362
};
362363

363364
// This would fail at runtime due to invalid connection, but we're testing the structure
364-
let result = create_lib_context(settings);
365+
let result = create_lib_context(settings).await;
365366
// We expect this to fail due to invalid connection, but the structure should be correct
366367
assert!(result.is_err());
367368
}

src/ops/shared/postgres.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub async fn get_db_pool(
1010
db_ref: Option<&spec::AuthEntryReference<DatabaseConnectionSpec>>,
1111
auth_registry: &AuthRegistry,
1212
) -> Result<PgPool> {
13-
let lib_context = get_lib_context()?;
13+
let lib_context = get_lib_context().await?;
1414
let db_conn_spec = db_ref
1515
.as_ref()
1616
.map(|db_ref| auth_registry.get(db_ref))

src/py/mod.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,17 @@ impl<T> AnyhowIntoPyResult<T> for anyhow::Result<T> {
9797
#[pyfunction]
9898
fn init(py: Python<'_>, settings: Pythonized<Settings>) -> PyResult<()> {
9999
py.allow_threads(|| -> anyhow::Result<()> {
100-
init_lib_context(settings.into_inner())?;
101-
Ok(())
100+
get_runtime().block_on(async move { init_lib_context(settings.into_inner()).await })
102101
})
103102
.into_py_result()
104103
}
105104

106105
#[pyfunction]
107106
fn start_server(py: Python<'_>, settings: Pythonized<ServerSettings>) -> PyResult<()> {
108107
py.allow_threads(|| -> anyhow::Result<()> {
109-
let server = get_runtime().block_on(server::init_server(
110-
get_lib_context()?,
111-
settings.into_inner(),
112-
))?;
108+
let server = get_runtime().block_on(async move {
109+
server::init_server(get_lib_context().await?, settings.into_inner()).await
110+
})?;
113111
get_runtime().spawn(server);
114112
Ok(())
115113
})
@@ -202,7 +200,7 @@ impl FlowLiveUpdater {
202200
) -> PyResult<Bound<'py, PyAny>> {
203201
let flow = flow.0.clone();
204202
future_into_py(py, async move {
205-
let lib_context = get_lib_context().into_py_result()?;
203+
let lib_context = get_lib_context().await.into_py_result()?;
206204
let live_updater = execution::FlowLiveUpdater::start(
207205
flow,
208206
lib_context.require_builtin_db_pool().into_py_result()?,
@@ -262,7 +260,7 @@ impl Flow {
262260
get_runtime()
263261
.block_on(async {
264262
let exec_plan = self.0.flow.get_execution_plan().await?;
265-
let lib_context = get_lib_context()?;
263+
let lib_context = get_lib_context().await?;
266264
let execution_ctx = self.0.use_execution_ctx().await?;
267265
execution::dumper::evaluate_and_dump(
268266
&exec_plan,
@@ -457,9 +455,9 @@ pub struct SetupChangeBundle(Arc<setup::SetupChangeBundle>);
457455
#[pymethods]
458456
impl SetupChangeBundle {
459457
pub fn describe_async<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
460-
let lib_context = get_lib_context().into_py_result()?;
461458
let bundle = self.0.clone();
462459
future_into_py(py, async move {
460+
let lib_context = get_lib_context().await.into_py_result()?;
463461
bundle.describe(&lib_context).await.into_py_result()
464462
})
465463
}
@@ -469,10 +467,10 @@ impl SetupChangeBundle {
469467
py: Python<'py>,
470468
report_to_stdout: bool,
471469
) -> PyResult<Bound<'py, PyAny>> {
472-
let lib_context = get_lib_context().into_py_result()?;
473470
let bundle = self.0.clone();
474471

475472
future_into_py(py, async move {
473+
let lib_context = get_lib_context().await.into_py_result()?;
476474
let mut stdout = None;
477475
let mut sink = None;
478476
bundle
@@ -493,7 +491,7 @@ impl SetupChangeBundle {
493491
#[pyfunction]
494492
fn flow_names_with_setup_async(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
495493
future_into_py(py, async move {
496-
let lib_context = get_lib_context().into_py_result()?;
494+
let lib_context = get_lib_context().await.into_py_result()?;
497495
let setup_ctx = lib_context
498496
.require_persistence_ctx()
499497
.into_py_result()?
@@ -524,11 +522,15 @@ fn make_drop_bundle(flow_names: Vec<String>) -> PyResult<SetupChangeBundle> {
524522
}
525523

526524
#[pyfunction]
527-
fn remove_flow_context(flow_name: String) {
528-
let lib_context_locked = crate::lib_context::LIB_CONTEXT.read().unwrap();
529-
if let Some(lib_context) = lib_context_locked.as_ref() {
530-
lib_context.remove_flow_context(&flow_name)
531-
}
525+
fn remove_flow_context(py: Python<'_>, flow_name: String) -> PyResult<()> {
526+
py.allow_threads(|| -> anyhow::Result<()> {
527+
get_runtime().block_on(async move {
528+
let lib_context = get_lib_context().await.into_py_result()?;
529+
lib_context.remove_flow_context(&flow_name);
530+
Ok(())
531+
})
532+
})
533+
.into_py_result()
532534
}
533535

534536
#[pyfunction]

src/setup/db_metadata.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ impl MetadataTableSetup {
355355
if !self.metadata_table_missing {
356356
return Ok(());
357357
}
358-
let lib_context = get_lib_context()?;
358+
let lib_context = get_lib_context().await?;
359359
let pool = lib_context.require_builtin_db_pool()?;
360360
let query_str = format!(
361361
"CREATE TABLE IF NOT EXISTS {SETUP_METADATA_TABLE_NAME} (

0 commit comments

Comments
 (0)