Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/cocoindex/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import op
from .auth_registry import TransientAuthEntryReference
from .setting import DatabaseConnectionSpec
from dataclasses import dataclass
import datetime


Expand Down Expand Up @@ -70,6 +71,15 @@ class AzureBlob(op.SourceSpec):
account_access_key: TransientAuthEntryReference[str] | None = None


@dataclass
class PostgresNotification:
"""Notification for a PostgreSQL table."""

# Optional: name of the PostgreSQL channel to use.
# If not provided, will generate a default channel name.
channel_name: str | None = None


class Postgres(op.SourceSpec):
"""Import data from a PostgreSQL table."""

Expand All @@ -87,3 +97,6 @@ class Postgres(op.SourceSpec):
# Optional: column name to use for ordinal tracking (for incremental updates)
# Should be a timestamp, serial, or other incrementing column
ordinal_column: str | None = None

# Optional: when set, supports change capture from PostgreSQL notification.
notification: PostgresNotification | None = None
252 changes: 248 additions & 4 deletions src/ops/sources/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ use crate::ops::sdk::*;

use crate::ops::shared::postgres::{bind_key_field, get_db_pool};
use crate::settings::DatabaseConnectionSpec;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use indoc::formatdoc;
use sqlx::postgres::types::PgInterval;
use sqlx::postgres::{PgListener, PgNotification};
use sqlx::{PgPool, Row};

type PgValueDecoder = fn(&sqlx::postgres::PgRow, usize) -> Result<Value>;
Expand All @@ -13,6 +17,11 @@ struct FieldSchemaInfo {
decoder: PgValueDecoder,
}

#[derive(Debug, Clone, Deserialize)]
pub struct NotificationSpec {
channel_name: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct Spec {
/// Table name to read from (required)
Expand All @@ -23,6 +32,8 @@ pub struct Spec {
included_columns: Option<Vec<String>>,
/// Optional: ordinal column for tracking changes
ordinal_column: Option<String>,
/// Optional: notification for change capture
notification: Option<NotificationSpec>,
}

#[derive(Clone)]
Expand All @@ -33,13 +44,14 @@ struct PostgresTableSchema {
ordinal_field_schema: Option<FieldSchemaInfo>,
}

struct Executor {
struct PostgresSourceExecutor {
db_pool: PgPool,
table_name: String,
table_schema: PostgresTableSchema,
notification: Option<NotificationSpec>,
}

impl Executor {
impl PostgresSourceExecutor {
/// Append value and ordinal columns to the provided columns vector.
/// Returns the optional index of the ordinal column in the final selection.
fn build_selected_columns(
Expand Down Expand Up @@ -377,7 +389,7 @@ fn value_to_ordinal(value: &Value) -> Ordinal {
}

#[async_trait]
impl SourceExecutor for Executor {
impl SourceExecutor for PostgresSourceExecutor {
async fn list(
&self,
options: &SourceExecutorReadOptions,
Expand Down Expand Up @@ -480,11 +492,242 @@ impl SourceExecutor for Executor {
Ok(data)
}

async fn change_stream(
&self,
) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
let Some(notification_spec) = &self.notification else {
return Ok(None);
};

let channel_name = notification_spec
.channel_name
.as_ref()
.ok_or_else(|| anyhow::anyhow!("channel_name is required for change_stream"))?;
let function_name = format!("notify__{channel_name}");
let trigger_name = format!("{function_name}__trigger");

// Create the notification function
self.create_notification_function(&function_name, channel_name, &trigger_name)
.await?;

// Set up listener
let mut listener = PgListener::connect_with(&self.db_pool).await?;
listener.listen(&channel_name).await?;

let stream = try_stream! {
while let Ok(notification) = listener.recv().await {
let change = self.parse_notification_payload(&notification)?;
yield SourceChangeMessage {
changes: vec![change],
ack_fn: None,
};
}
};

Ok(Some(stream.boxed()))
}

fn provides_ordinal(&self) -> bool {
self.table_schema.ordinal_field_schema.is_some()
}
}

impl PostgresSourceExecutor {
async fn create_notification_function(
&self,
function_name: &str,
channel_name: &str,
trigger_name: &str,
) -> Result<()> {
let json_object_expr = |var: &str| {
let mut fields = (self.table_schema.primary_key_columns.iter())
.chain(self.table_schema.ordinal_field_schema.iter())
.map(|col| {
let field_name = &col.schema.name;
if matches!(
col.schema.value_type.typ,
ValueType::Basic(BasicValueType::Bytes)
) {
format!("'{field_name}', encode({var}.\"{field_name}\", 'base64')")
} else {
format!("'{field_name}', {var}.\"{field_name}\"")
}
});
format!("jsonb_build_object({})", fields.join(", "))
};

let statements = [
formatdoc! {r#"
CREATE OR REPLACE FUNCTION {function_name}() RETURNS TRIGGER AS $$
BEGIN
PERFORM pg_notify('{channel_name}', jsonb_build_object(
'op', TG_OP,
'fields',
CASE WHEN TG_OP IN ('INSERT', 'UPDATE') THEN {json_object_expr_new}
WHEN TG_OP = 'DELETE' THEN {json_object_expr_old}
ELSE NULL END
)::text);
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
"#,
function_name = function_name,
channel_name = channel_name,
json_object_expr_new = json_object_expr("NEW"),
json_object_expr_old = json_object_expr("OLD"),
},
format!(
"DROP TRIGGER IF EXISTS {trigger_name} ON \"{table_name}\";",
trigger_name = trigger_name,
table_name = self.table_name,
),
formatdoc! {r#"
CREATE TRIGGER {trigger_name}
AFTER INSERT OR UPDATE OR DELETE ON "{table_name}"
FOR EACH ROW EXECUTE FUNCTION {function_name}();
"#,
trigger_name = trigger_name,
table_name = self.table_name,
function_name = function_name,
},
];

let mut tx = self.db_pool.begin().await?;
for stmt in statements {
sqlx::query(&stmt).execute(&mut *tx).await?;
}
tx.commit().await?;
Ok(())
}

fn parse_notification_payload(&self, notification: &PgNotification) -> Result<SourceChange> {
let mut payload: serde_json::Value = serde_json::from_str(notification.payload())?;
let payload = payload
.as_object_mut()
.ok_or_else(|| anyhow::anyhow!("'fields' field is not an object"))?;

let Some(serde_json::Value::String(op)) = payload.get_mut("op") else {
return Err(anyhow::anyhow!(
"Missing or invalid 'op' field in notification"
));
};
let op = std::mem::take(op);

let mut fields = std::mem::take(
payload
.get_mut("fields")
.ok_or_else(|| anyhow::anyhow!("Missing 'fields' field in notification"))?
.as_object_mut()
.ok_or_else(|| anyhow::anyhow!("'fields' field is not an object"))?,
);

// Extract primary key values to construct the key
let mut key_parts = Vec::with_capacity(self.table_schema.primary_key_columns.len());
for pk_col in &self.table_schema.primary_key_columns {
let field_value = fields.get_mut(&pk_col.schema.name).ok_or_else(|| {
anyhow::anyhow!("Missing primary key field: {}", pk_col.schema.name)
})?;

let key_part = Self::decode_key_ordinal_value_in_json(
std::mem::take(field_value),
&pk_col.schema.value_type.typ,
)?
.into_key()?;
key_parts.push(key_part);
}

let key = KeyValue(key_parts.into_boxed_slice());

// Extract ordinal if available
let ordinal = if let Some(ord_schema) = &self.table_schema.ordinal_field_schema {
if let Some(ord_value) = fields.get_mut(&ord_schema.schema.name) {
let value = Self::decode_key_ordinal_value_in_json(
std::mem::take(ord_value),
&ord_schema.schema.value_type.typ,
)?;
Some(value_to_ordinal(&value))
} else {
Some(Ordinal::unavailable())
}
} else {
None
};

let data = match op.as_str() {
"DELETE" => PartialSourceRowData {
value: Some(SourceValue::NonExistence),
ordinal,
content_version_fp: None,
},
"INSERT" | "UPDATE" => {
// For INSERT/UPDATE, we signal that the row exists but don't include the full value
// The engine will call get_value() to retrieve the actual data
PartialSourceRowData {
value: None, // Let the engine fetch the value
ordinal,
content_version_fp: None,
}
}
_ => return Err(anyhow::anyhow!("Unknown operation: {}", op)),
};

Ok(SourceChange {
key,
key_aux_info: serde_json::Value::Null,
data,
})
}

fn decode_key_ordinal_value_in_json(
json_value: serde_json::Value,
value_type: &ValueType,
) -> Result<Value> {
let result = match (value_type, json_value) {
(_, serde_json::Value::Null) => Value::Null,
(ValueType::Basic(BasicValueType::Bool), serde_json::Value::Bool(b)) => {
BasicValue::Bool(b).into()
}
(ValueType::Basic(BasicValueType::Bytes), serde_json::Value::String(s)) => {
let bytes = BASE64_STANDARD.decode(&s)?;
BasicValue::Bytes(bytes::Bytes::from(bytes)).into()
}
(ValueType::Basic(BasicValueType::Str), serde_json::Value::String(s)) => {
BasicValue::Str(s.into()).into()
}
(ValueType::Basic(BasicValueType::Int64), serde_json::Value::Number(n)) => {
if let Some(i) = n.as_i64() {
BasicValue::Int64(i).into()
} else {
bail!("Invalid integer value: {}", n)
}
}
(ValueType::Basic(BasicValueType::Uuid), serde_json::Value::String(s)) => {
let uuid = s.parse::<uuid::Uuid>()?;
BasicValue::Uuid(uuid).into()
}
(ValueType::Basic(BasicValueType::Date), serde_json::Value::String(s)) => {
let dt = s.parse::<chrono::NaiveDate>()?;
BasicValue::Date(dt).into()
}
(ValueType::Basic(BasicValueType::LocalDateTime), serde_json::Value::String(s)) => {
let dt = s.parse::<chrono::NaiveDateTime>()?;
BasicValue::LocalDateTime(dt).into()
}
(ValueType::Basic(BasicValueType::OffsetDateTime), serde_json::Value::String(s)) => {
let dt = s.parse::<chrono::DateTime<chrono::FixedOffset>>()?;
BasicValue::OffsetDateTime(dt).into()
}
(_, json_value) => {
bail!(
"Got unsupported JSON value for type {value_type}: {}",
serde_json::to_string(&json_value)?
);
}
};
Ok(result)
}
}

pub struct Factory;

#[async_trait]
Expand Down Expand Up @@ -545,10 +788,11 @@ impl SourceFactoryBase for Factory {
)
.await?;

let executor = Executor {
let executor = PostgresSourceExecutor {
db_pool,
table_name: spec.table_name.clone(),
table_schema,
notification: spec.notification.clone(),
};

Ok(Box::new(executor))
Expand Down
Loading