diff --git a/python/cocoindex/sources.py b/python/cocoindex/sources.py
index 0850d9be4..df409b527 100644
--- a/python/cocoindex/sources.py
+++ b/python/cocoindex/sources.py
@@ -3,6 +3,7 @@
 from . import op
 from .auth_registry import TransientAuthEntryReference
 from .setting import DatabaseConnectionSpec
+from dataclasses import dataclass
 
 import datetime
@@ -70,6 +71,15 @@ class AzureBlob(op.SourceSpec):
     account_access_key: TransientAuthEntryReference[str] | None = None
 
 
+@dataclass
+class PostgresNotification:
+    """Change-capture notification settings for a PostgreSQL table."""
+
+    # Optional: name of the PostgreSQL channel to use.
+    # If not provided, a default channel name will be generated.
+    channel_name: str | None = None
+
+
 class Postgres(op.SourceSpec):
     """Import data from a PostgreSQL table."""
@@ -87,3 +97,6 @@ class Postgres(op.SourceSpec):
     # Optional: column name to use for ordinal tracking (for incremental updates).
     # Should be a timestamp, serial, or other monotonically increasing column.
     ordinal_column: str | None = None
+
+    # Optional: when set, enables change capture via PostgreSQL notifications.
+    notification: PostgresNotification | None = None
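The new spec surfaces directly through the Python API. A minimal sketch of enabling change capture (the table, column, and channel names are illustrative, not from this PR); note that while the comment above anticipates a generated default channel name, the Rust executor below still requires an explicit `channel_name`:

```python
import cocoindex

# Hypothetical table with an `id` primary key and an `updated_at` timestamp.
source = cocoindex.sources.Postgres(
    table_name="documents",
    ordinal_column="updated_at",
    notification=cocoindex.sources.PostgresNotification(
        channel_name="documents_cdc",
    ),
)
```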
diff --git a/src/ops/sources/postgres.rs b/src/ops/sources/postgres.rs
index d5ba4b88e..2755f08ff 100644
--- a/src/ops/sources/postgres.rs
+++ b/src/ops/sources/postgres.rs
@@ -2,7 +2,11 @@
 use crate::ops::sdk::*;
 use crate::ops::shared::postgres::{bind_key_field, get_db_pool};
 use crate::settings::DatabaseConnectionSpec;
+use base64::Engine;
+use base64::prelude::BASE64_STANDARD;
+use indoc::formatdoc;
 use sqlx::postgres::types::PgInterval;
+use sqlx::postgres::{PgListener, PgNotification};
 use sqlx::{PgPool, Row};
 
 type PgValueDecoder = fn(&sqlx::postgres::PgRow, usize) -> Result<Value>;
@@ -13,6 +17,11 @@ struct FieldSchemaInfo {
     decoder: PgValueDecoder,
 }
 
+#[derive(Debug, Clone, Deserialize)]
+pub struct NotificationSpec {
+    channel_name: Option<String>,
+}
+
 #[derive(Debug, Deserialize)]
 pub struct Spec {
     /// Table name to read from (required)
@@ -23,6 +32,8 @@ pub struct Spec {
     included_columns: Option<Vec<String>>,
     /// Optional: ordinal column for tracking changes
     ordinal_column: Option<String>,
+    /// Optional: notification-based change capture settings
+    notification: Option<NotificationSpec>,
 }
 
 #[derive(Clone)]
 struct PostgresTableSchema {
     ordinal_field_schema: Option<FieldSchemaInfo>,
 }
 
-struct Executor {
+struct PostgresSourceExecutor {
     db_pool: PgPool,
     table_name: String,
     table_schema: PostgresTableSchema,
+    notification: Option<NotificationSpec>,
 }
 
-impl Executor {
+impl PostgresSourceExecutor {
     /// Append value and ordinal columns to the provided columns vector.
     /// Returns the optional index of the ordinal column in the final selection.
     fn build_selected_columns(
@@ -377,7 +389,7 @@ fn value_to_ordinal(value: &Value) -> Ordinal {
 }
 
 #[async_trait]
-impl SourceExecutor for Executor {
+impl SourceExecutor for PostgresSourceExecutor {
     async fn list(
         &self,
         options: &SourceExecutorReadOptions,
@@ -480,11 +492,242 @@ impl SourceExecutor for PostgresSourceExecutor {
         Ok(data)
     }
 
+    async fn change_stream(
+        &self,
+    ) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
+        let Some(notification_spec) = &self.notification else {
+            return Ok(None);
+        };
+
+        let channel_name = notification_spec
+            .channel_name
+            .as_ref()
+            .ok_or_else(|| anyhow::anyhow!("channel_name is required for change_stream"))?;
+        let function_name = format!("notify__{channel_name}");
+        let trigger_name = format!("{function_name}__trigger");
+
+        // Create the notification function and its trigger on the source table.
+        self.create_notification_function(&function_name, channel_name, &trigger_name)
+            .await?;
+
+        // Set up the listener on the notification channel.
+        let mut listener = PgListener::connect_with(&self.db_pool).await?;
+        listener.listen(channel_name).await?;
+
+        let stream = try_stream! {
+            while let Ok(notification) = listener.recv().await {
+                let change = self.parse_notification_payload(&notification)?;
+                yield SourceChangeMessage {
+                    changes: vec![change],
+                    ack_fn: None,
+                };
+            }
+        };
+
+        Ok(Some(stream.boxed()))
+    }
+
     fn provides_ordinal(&self) -> bool {
         self.table_schema.ordinal_field_schema.is_some()
     }
 }
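Since `change_stream` rides on plain LISTEN/NOTIFY, the channel can also be observed outside the engine, which is handy for debugging the trigger set up below. A standalone sketch assuming psycopg2 and the illustrative channel name from above (none of this is part of the PR):

```python
import json
import select

import psycopg2
import psycopg2.extensions

# Hypothetical DSN; LISTEN/NOTIFY requires autocommit in psycopg2.
conn = psycopg2.connect("postgresql://localhost/mydb")
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

cur = conn.cursor()
cur.execute('LISTEN "documents_cdc";')  # same channel PgListener subscribes to

while True:
    # Wait until the connection's socket is readable, then drain notifications.
    if select.select([conn], [], [], 10.0) == ([], [], []):
        continue
    conn.poll()
    while conn.notifies:
        n = conn.notifies.pop(0)
        print(json.loads(n.payload))  # {"op": ..., "fields": {...}}
```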
 
+impl PostgresSourceExecutor {
+    async fn create_notification_function(
+        &self,
+        function_name: &str,
+        channel_name: &str,
+        trigger_name: &str,
+    ) -> Result<()> {
+        let json_object_expr = |var: &str| {
+            let mut fields = (self.table_schema.primary_key_columns.iter())
+                .chain(self.table_schema.ordinal_field_schema.iter())
+                .map(|col| {
+                    let field_name = &col.schema.name;
+                    if matches!(
+                        col.schema.value_type.typ,
+                        ValueType::Basic(BasicValueType::Bytes)
+                    ) {
+                        format!("'{field_name}', encode({var}.\"{field_name}\", 'base64')")
+                    } else {
+                        format!("'{field_name}', {var}.\"{field_name}\"")
+                    }
+                });
+            format!("jsonb_build_object({})", fields.join(", "))
+        };
+
+        let statements = [
+            formatdoc! {r#"
+                CREATE OR REPLACE FUNCTION {function_name}() RETURNS TRIGGER AS $$
+                BEGIN
+                    PERFORM pg_notify('{channel_name}', jsonb_build_object(
+                        'op', TG_OP,
+                        'fields',
+                        CASE WHEN TG_OP IN ('INSERT', 'UPDATE') THEN {json_object_expr_new}
+                             WHEN TG_OP = 'DELETE' THEN {json_object_expr_old}
+                             ELSE NULL END
+                    )::text);
+                    RETURN NULL;
+                END;
+                $$ LANGUAGE plpgsql;
+                "#,
+                function_name = function_name,
+                channel_name = channel_name,
+                json_object_expr_new = json_object_expr("NEW"),
+                json_object_expr_old = json_object_expr("OLD"),
+            },
+            format!(
+                "DROP TRIGGER IF EXISTS {trigger_name} ON \"{table_name}\";",
+                trigger_name = trigger_name,
+                table_name = self.table_name,
+            ),
+            formatdoc! {r#"
+                CREATE TRIGGER {trigger_name}
+                AFTER INSERT OR UPDATE OR DELETE ON "{table_name}"
+                FOR EACH ROW EXECUTE FUNCTION {function_name}();
+                "#,
+                trigger_name = trigger_name,
+                table_name = self.table_name,
+                function_name = function_name,
+            },
+        ];
+
+        let mut tx = self.db_pool.begin().await?;
+        for stmt in statements {
+            sqlx::query(&stmt).execute(&mut *tx).await?;
+        }
+        tx.commit().await?;
+        Ok(())
+    }
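To make the generated payload concrete: for a hypothetical table with an integer `id` primary key and an `updated_at` ordinal column, the trigger above emits notifications shaped like this (`fields` carries only key and ordinal columns, taken from `NEW` on INSERT/UPDATE and from `OLD` on DELETE):

```python
# Sketch of the pg_notify payload; column names are illustrative.
insert_or_update_payload = {
    "op": "UPDATE",  # TG_OP: "INSERT", "UPDATE", or "DELETE"
    "fields": {"id": 42, "updated_at": "2024-05-01T12:34:56.789+00:00"},
}

delete_payload = {
    "op": "DELETE",
    "fields": {"id": 42, "updated_at": "2024-05-01T12:34:56.789+00:00"},  # from OLD
}

# bytea key columns are not JSON-representable, so the trigger emits them
# as base64 via encode(..., 'base64') instead.
```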
} + } + (ValueType::Basic(BasicValueType::Uuid), serde_json::Value::String(s)) => { + let uuid = s.parse::()?; + BasicValue::Uuid(uuid).into() + } + (ValueType::Basic(BasicValueType::Date), serde_json::Value::String(s)) => { + let dt = s.parse::()?; + BasicValue::Date(dt).into() + } + (ValueType::Basic(BasicValueType::LocalDateTime), serde_json::Value::String(s)) => { + let dt = s.parse::()?; + BasicValue::LocalDateTime(dt).into() + } + (ValueType::Basic(BasicValueType::OffsetDateTime), serde_json::Value::String(s)) => { + let dt = s.parse::>()?; + BasicValue::OffsetDateTime(dt).into() + } + (_, json_value) => { + bail!( + "Got unsupported JSON value for type {value_type}: {}", + serde_json::to_string(&json_value)? + ); + } + }; + Ok(result) + } +} + pub struct Factory; #[async_trait] @@ -545,10 +788,11 @@ impl SourceFactoryBase for Factory { ) .await?; - let executor = Executor { + let executor = PostgresSourceExecutor { db_pool, table_name: spec.table_name.clone(), table_schema, + notification: spec.notification.clone(), }; Ok(Box::new(executor))