Skip to content

Commit a812d93

Browse files
authored
feat(postgres-sql-attachments): add PostgresSqlAttachment (#1141)
1 parent 43c988f commit a812d93

File tree

8 files changed

+206
-37
lines changed

8 files changed

+206
-37
lines changed

python/cocoindex/flow.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -437,7 +437,10 @@ def export(
437437
target_name,
438438
_spec_kind(target_spec),
439439
dump_engine_object(target_spec),
440-
dump_engine_object(attachments),
440+
[
441+
{"kind": _spec_kind(att), **dump_engine_object(att)}
442+
for att in attachments
443+
],
441444
dump_engine_object(index_options),
442445
self._engine_data_collector,
443446
setup_by_user,

python/cocoindex/targets/_engine_builtin_specs.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,14 @@ class Postgres(op.TargetSpec):
1616
table_name: str | None = None
1717

1818

19+
class PostgresSqlAttachment(op.TargetAttachmentSpec):
20+
"""Attachment to execute specified SQL statements for Postgres targets."""
21+
22+
name: str
23+
setup_sql: str
24+
teardown_sql: str | None = None
25+
26+
1927
@dataclass
2028
class QdrantConnection:
2129
"""Connection spec for Qdrant."""

src/ops/factory_bases.rs

Lines changed: 35 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -642,32 +642,48 @@ pub struct TypedTargetAttachmentState<F: TargetSpecificAttachmentFactoryBase + ?
642642
}
643643

644644
/// A factory for target-specific attachments.
645+
#[async_trait]
645646
pub trait TargetSpecificAttachmentFactoryBase: Send + Sync + 'static {
647+
type TargetKey: Debug + Clone + Serialize + DeserializeOwned + Eq + Hash + Send + Sync;
646648
type TargetSpec: DeserializeOwned + Send + Sync;
647649
type Spec: DeserializeOwned + Send + Sync;
648650
type SetupKey: Debug + Clone + Serialize + DeserializeOwned + Eq + Hash + Send + Sync;
649651
type SetupState: Debug + Clone + Serialize + DeserializeOwned + Send + Sync;
650652
type SetupChange: interface::AttachmentSetupChange + Send + Sync;
651653

654+
fn name(&self) -> &str;
655+
652656
fn get_state(
653657
&self,
654658
target_name: &str,
655659
target_spec: &Self::TargetSpec,
656660
attachment_spec: Self::Spec,
657661
) -> Result<TypedTargetAttachmentState<Self>>;
658662

659-
fn diff_setup_states(
663+
async fn diff_setup_states(
660664
&self,
661-
key: &serde_json::Value,
662-
new_state: Option<serde_json::Value>,
663-
existing_states: setup::CombinedState<serde_json::Value>,
665+
target_key: &Self::TargetKey,
666+
attachment_key: &Self::SetupKey,
667+
new_state: Option<Self::SetupState>,
668+
existing_states: setup::CombinedState<Self::SetupState>,
669+
context: &interface::FlowInstanceContext,
664670
) -> Result<Option<Self::SetupChange>>;
665671

666672
/// Deserialize the setup key from a JSON value.
667673
/// You can override this method to provide custom deserialization logic, e.g. to perform backward-compatible deserialization.
668674
fn deserialize_setup_key(key: serde_json::Value) -> Result<Self::SetupKey> {
669675
Ok(utils::deser::from_json_value(key)?)
670676
}
677+
678+
fn register(self, registry: &mut ExecutorFactoryRegistry) -> Result<()>
679+
where
680+
Self: Sized,
681+
{
682+
registry.register(
683+
self.name().to_string(),
684+
ExecutorFactory::TargetAttachment(Arc::new(self)),
685+
)
686+
}
671687
}
672688

673689
#[async_trait]
@@ -695,19 +711,25 @@ impl<T: TargetSpecificAttachmentFactoryBase> TargetAttachmentFactory for T {
695711
})
696712
}
697713

698-
fn diff_setup_states(
714+
async fn diff_setup_states(
699715
&self,
700-
key: &serde_json::Value,
716+
target_key: &serde_json::Value,
717+
attachment_key: &serde_json::Value,
701718
new_state: Option<serde_json::Value>,
702719
existing_states: setup::CombinedState<serde_json::Value>,
720+
context: &interface::FlowInstanceContext,
703721
) -> Result<Option<Box<dyn AttachmentSetupChange + Send + Sync>>> {
704-
let setup_change = self.diff_setup_states(
705-
&utils::deser::from_json_value(key.clone())?,
706-
new_state
707-
.map(|v| utils::deser::from_json_value(v))
708-
.transpose()?,
709-
from_json_combined_state(existing_states)?,
710-
)?;
722+
let setup_change = self
723+
.diff_setup_states(
724+
&utils::deser::from_json_value(target_key.clone())?,
725+
&utils::deser::from_json_value(attachment_key.clone())?,
726+
new_state
727+
.map(|v| utils::deser::from_json_value(v))
728+
.transpose()?,
729+
from_json_combined_state(existing_states)?,
730+
context,
731+
)
732+
.await?;
711733
Ok(setup_change.map(|s| Box::new(s) as Box<dyn AttachmentSetupChange + Send + Sync>))
712734
}
713735
}

src/ops/interface.rs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -323,11 +323,12 @@ pub struct TargetAttachmentState {
323323

324324
#[async_trait]
325325
pub trait AttachmentSetupChange {
326-
fn describe_change(&self) -> String;
326+
fn describe_changes(&self) -> Vec<String>;
327327

328328
async fn apply_change(&self) -> Result<()>;
329329
}
330330

331+
#[async_trait]
331332
pub trait TargetAttachmentFactory: Send + Sync {
332333
/// Normalize the key. The JSON format may change (after a code change, e.g. a new optional field or a different field ordering) even if the underlying value has not changed.
333334
/// This should always return the canonical serialized form.
@@ -341,11 +342,13 @@ pub trait TargetAttachmentFactory: Send + Sync {
341342
) -> Result<TargetAttachmentState>;
342343

343344
/// Should return Some if and only if any changes are needed.
344-
fn diff_setup_states(
345+
async fn diff_setup_states(
345346
&self,
346-
key: &serde_json::Value,
347+
target_key: &serde_json::Value,
348+
attachment_key: &serde_json::Value,
347349
new_state: Option<serde_json::Value>,
348350
existing_states: setup::CombinedState<serde_json::Value>,
351+
context: &interface::FlowInstanceContext,
349352
) -> Result<Option<Box<dyn AttachmentSetupChange + Send + Sync>>>;
350353
}
351354

src/ops/registration.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
2020
functions::embed_text::register(registry)?;
2121
functions::split_by_separators::register(registry)?;
2222

23-
targets::postgres::Factory::default().register(registry)?;
23+
targets::postgres::register(registry)?;
2424
targets::qdrant::register(registry)?;
2525
targets::kuzu::register(registry, reqwest_client)?;
2626

src/ops/targets/postgres.rs

Lines changed: 122 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -248,8 +248,7 @@ impl ExportContext {
248248
}
249249
}
250250

251-
#[derive(Default)]
252-
pub struct Factory {}
251+
struct TargetFactory;
253252

254253
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
255254
pub struct TableId {
@@ -614,7 +613,7 @@ impl SetupChange {
614613
}
615614

616615
#[async_trait]
617-
impl TargetFactoryBase for Factory {
616+
impl TargetFactoryBase for TargetFactory {
618617
type Spec = Spec;
619618
type DeclarationSpec = ();
620619
type SetupState = SetupState;
@@ -752,3 +751,123 @@ impl TargetFactoryBase for Factory {
752751
Ok(())
753752
}
754753
}
754+
755+
////////////////////////////////////////////////////////////
756+
// Attachment Factory
757+
////////////////////////////////////////////////////////////
758+
759+
#[derive(Debug, Clone, Serialize, Deserialize)]
760+
pub struct SqlStatementAttachmentSpec {
761+
name: String,
762+
setup_sql: String,
763+
teardown_sql: Option<String>,
764+
}
765+
766+
#[derive(Debug, Clone, Serialize, Deserialize)]
767+
pub struct SqlStatementAttachmentState {
768+
setup_sql: String,
769+
teardown_sql: Option<String>,
770+
}
771+
772+
pub struct SqlStatementAttachmentSetupChange {
773+
db_pool: PgPool,
774+
setup_sql_to_run: Option<String>,
775+
teardown_sql_to_run: IndexSet<String>,
776+
}
777+
778+
#[async_trait]
779+
impl AttachmentSetupChange for SqlStatementAttachmentSetupChange {
780+
fn describe_changes(&self) -> Vec<String> {
781+
let mut result = vec![];
782+
for teardown_sql in self.teardown_sql_to_run.iter() {
783+
result.push(format!("Run teardown SQL: {}", teardown_sql));
784+
}
785+
if let Some(setup_sql) = &self.setup_sql_to_run {
786+
result.push(format!("Run setup SQL: {}", setup_sql));
787+
}
788+
result
789+
}
790+
791+
async fn apply_change(&self) -> Result<()> {
792+
for teardown_sql in self.teardown_sql_to_run.iter() {
793+
sqlx::query(teardown_sql).execute(&self.db_pool).await?;
794+
}
795+
if let Some(setup_sql) = &self.setup_sql_to_run {
796+
sqlx::query(setup_sql).execute(&self.db_pool).await?;
797+
}
798+
Ok(())
799+
}
800+
}
801+
802+
struct SqlAttachmentFactory;
803+
804+
#[async_trait]
805+
impl TargetSpecificAttachmentFactoryBase for SqlAttachmentFactory {
806+
type TargetKey = TableId;
807+
type TargetSpec = Spec;
808+
type Spec = SqlStatementAttachmentSpec;
809+
type SetupKey = String;
810+
type SetupState = SqlStatementAttachmentState;
811+
type SetupChange = SqlStatementAttachmentSetupChange;
812+
813+
fn name(&self) -> &str {
814+
"PostgresSqlAttachment"
815+
}
816+
817+
fn get_state(
818+
&self,
819+
_target_name: &str,
820+
_target_spec: &Spec,
821+
attachment_spec: SqlStatementAttachmentSpec,
822+
) -> Result<TypedTargetAttachmentState<Self>> {
823+
Ok(TypedTargetAttachmentState {
824+
setup_key: attachment_spec.name,
825+
setup_state: SqlStatementAttachmentState {
826+
setup_sql: attachment_spec.setup_sql,
827+
teardown_sql: attachment_spec.teardown_sql,
828+
},
829+
})
830+
}
831+
832+
async fn diff_setup_states(
833+
&self,
834+
target_key: &TableId,
835+
_attachment_key: &String,
836+
new_state: Option<SqlStatementAttachmentState>,
837+
existing_states: setup::CombinedState<SqlStatementAttachmentState>,
838+
context: &interface::FlowInstanceContext,
839+
) -> Result<Option<SqlStatementAttachmentSetupChange>> {
840+
let teardown_sql_to_run: IndexSet<String> = if new_state.is_none() {
841+
existing_states
842+
.possible_versions()
843+
.filter_map(|s| s.teardown_sql.clone())
844+
.collect()
845+
} else {
846+
IndexSet::new()
847+
};
848+
let setup_sql_to_run = if let Some(new_state) = new_state
849+
&& !existing_states.always_exists_and(|s| s.setup_sql == new_state.setup_sql)
850+
{
851+
Some(new_state.setup_sql)
852+
} else {
853+
None
854+
};
855+
let change = if setup_sql_to_run.is_some() || !teardown_sql_to_run.is_empty() {
856+
let db_pool = get_db_pool(target_key.database.as_ref(), &context.auth_registry).await?;
857+
Some(SqlStatementAttachmentSetupChange {
858+
db_pool,
859+
setup_sql_to_run,
860+
teardown_sql_to_run,
861+
})
862+
} else {
863+
None
864+
};
865+
Ok(change)
866+
}
867+
}
868+
869+
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
870+
TargetFactory.register(registry)?;
871+
SqlAttachmentFactory.register(registry)?;
872+
Ok(())
873+
}

src/setup/driver.rs

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -256,9 +256,11 @@ fn group_states<K: Hash + Eq + std::fmt::Display + std::fmt::Debug + Clone, S: D
256256
Ok(grouped)
257257
}
258258

259-
fn collect_attachments_setup_change(
259+
async fn collect_attachments_setup_change(
260+
target_key: &serde_json::Value,
260261
desired: Option<&TargetSetupState>,
261262
existing: &CombinedState<TargetSetupState>,
263+
context: &interface::FlowInstanceContext,
262264
) -> Result<AttachmentsSetupChange> {
263265
let existing_current_attachments = existing
264266
.current
@@ -309,8 +311,15 @@ fn collect_attachments_setup_change(
309311
for (AttachmentSetupKey(kind, key), setup_state) in grouped_attachment_states.into_iter() {
310312
let factory = get_attachment_factory(&kind)?;
311313
let is_upsertion = setup_state.desired.is_some();
312-
if let Some(action) =
313-
factory.diff_setup_states(&key, setup_state.desired, setup_state.existing)?
314+
if let Some(action) = factory
315+
.diff_setup_states(
316+
&target_key,
317+
&key,
318+
setup_state.desired,
319+
setup_state.existing,
320+
context,
321+
)
322+
.await?
314323
{
315324
if is_upsertion {
316325
attachments_change.upserts.push(action);
@@ -411,9 +420,12 @@ pub async fn diff_flow_setup_states(
411420
};
412421

413422
let attachments_change = collect_attachments_setup_change(
423+
&resource_id.key,
414424
target_states_group.desired.as_ref(),
415425
&target_states_group.existing,
416-
)?;
426+
&flow_instance_ctx,
427+
)
428+
.await?;
417429

418430
let desired_state = target_states_group.desired.clone();
419431
let target_state = target_states_group

src/setup/states.rs

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,10 @@ impl<T> CombinedState<T> {
9191
self.current.is_some() && self.staging.iter().all(|s| !s.is_delete())
9292
}
9393

94+
pub fn always_exists_and(&self, predicate: impl Fn(&T) -> bool) -> bool {
95+
self.always_exists() && self.possible_versions().all(predicate)
96+
}
97+
9498
pub fn legacy_values<V: Ord + Eq, F: Fn(&T) -> &V>(
9599
&self,
96100
desired: Option<&T>,
@@ -420,19 +424,17 @@ pub struct TargetSetupChange {
420424
impl ResourceSetupChange for TargetSetupChange {
421425
fn describe_changes(&self) -> Vec<ChangeDescription> {
422426
let mut result = vec![];
423-
result.extend(
424-
self.attachments_change
425-
.deletes
426-
.iter()
427-
.map(|a| ChangeDescription::Action(a.describe_change())),
428-
);
427+
self.attachments_change
428+
.deletes
429+
.iter()
430+
.flat_map(|a| a.describe_changes().into_iter())
431+
.for_each(|change| result.push(ChangeDescription::Action(change)));
429432
result.extend(self.target_change.describe_changes());
430-
result.extend(
431-
self.attachments_change
432-
.upserts
433-
.iter()
434-
.map(|a| ChangeDescription::Action(a.describe_change())),
435-
);
433+
self.attachments_change
434+
.upserts
435+
.iter()
436+
.flat_map(|a| a.describe_changes().into_iter())
437+
.for_each(|change| result.push(ChangeDescription::Action(change)));
436438
result
437439
}
438440

0 commit comments

Comments
 (0)