Skip to content

Commit 0b29140

Browse files
committed
feat(postgres): add preliminary support for change notification
1 parent 85453e1 commit 0b29140

File tree

2 files changed

+248
-9
lines changed

2 files changed

+248
-9
lines changed

python/cocoindex/sources.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from . import op
44
from .auth_registry import TransientAuthEntryReference
55
from .setting import DatabaseConnectionSpec
6+
from dataclasses import dataclass
67
import datetime
78

89

@@ -70,12 +71,13 @@ class AzureBlob(op.SourceSpec):
7071
account_access_key: TransientAuthEntryReference[str] | None = None
7172

7273

74+
@dataclass
7375
class PostgresNotification:
7476
"""Notification for a PostgreSQL table."""
7577

76-
# Optional: name of the PostgreSQL notify function to use.
77-
# If not provided, the default notify function will be used.
78-
notify_function_name: str | None = None
78+
# Optional: name of the PostgreSQL channel to use.
79+
# If not provided, will generate a default channel name.
80+
channel_name: str | None = None
7981

8082

8183
class Postgres(op.SourceSpec):

src/ops/sources/postgres.rs

Lines changed: 243 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ use crate::ops::sdk::*;
22

33
use crate::ops::shared::postgres::{bind_key_field, get_db_pool};
44
use crate::settings::DatabaseConnectionSpec;
5+
use base64::Engine;
6+
use base64::prelude::BASE64_STANDARD;
7+
use indoc::formatdoc;
58
use sqlx::postgres::types::PgInterval;
9+
use sqlx::postgres::{PgListener, PgNotification};
610
use sqlx::{PgPool, Row};
711

812
type PgValueDecoder = fn(&sqlx::postgres::PgRow, usize) -> Result<Value>;
@@ -13,9 +17,9 @@ struct FieldSchemaInfo {
1317
decoder: PgValueDecoder,
1418
}
1519

16-
#[derive(Debug, Deserialize)]
20+
#[derive(Debug, Clone, Deserialize)]
1721
pub struct NotificationSpec {
18-
notify_function_name: Option<String>,
22+
channel_name: Option<String>,
1923
}
2024

2125
#[derive(Debug, Deserialize)]
@@ -40,13 +44,14 @@ struct PostgresTableSchema {
4044
ordinal_field_schema: Option<FieldSchemaInfo>,
4145
}
4246

43-
struct Executor {
47+
struct PostgresSourceExecutor {
4448
db_pool: PgPool,
4549
table_name: String,
4650
table_schema: PostgresTableSchema,
51+
notification: Option<NotificationSpec>,
4752
}
4853

49-
impl Executor {
54+
impl PostgresSourceExecutor {
5055
/// Append value and ordinal columns to the provided columns vector.
5156
/// Returns the optional index of the ordinal column in the final selection.
5257
fn build_selected_columns(
@@ -384,7 +389,7 @@ fn value_to_ordinal(value: &Value) -> Ordinal {
384389
}
385390

386391
#[async_trait]
387-
impl SourceExecutor for Executor {
392+
impl SourceExecutor for PostgresSourceExecutor {
388393
async fn list(
389394
&self,
390395
options: &SourceExecutorReadOptions,
@@ -487,11 +492,242 @@ impl SourceExecutor for Executor {
487492
Ok(data)
488493
}
489494

495+
async fn change_stream(
496+
&self,
497+
) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
498+
let Some(notification_spec) = &self.notification else {
499+
return Ok(None);
500+
};
501+
502+
let channel_name = notification_spec
503+
.channel_name
504+
.as_ref()
505+
.ok_or_else(|| anyhow::anyhow!("channel_name is required for change_stream"))?;
506+
let function_name = format!("notify__{channel_name}");
507+
let trigger_name = format!("{function_name}__trigger");
508+
509+
// Create the notification function
510+
self.create_notification_function(&function_name, channel_name, &trigger_name)
511+
.await?;
512+
513+
// Set up listener
514+
let mut listener = PgListener::connect_with(&self.db_pool).await?;
515+
listener.listen(&channel_name).await?;
516+
517+
let stream = try_stream! {
518+
while let Ok(notification) = listener.recv().await {
519+
let change = self.parse_notification_payload(&notification)?;
520+
yield SourceChangeMessage {
521+
changes: vec![change],
522+
ack_fn: None,
523+
};
524+
}
525+
};
526+
527+
Ok(Some(stream.boxed()))
528+
}
529+
490530
fn provides_ordinal(&self) -> bool {
491531
self.table_schema.ordinal_field_schema.is_some()
492532
}
493533
}
494534

535+
impl PostgresSourceExecutor {
536+
async fn create_notification_function(
537+
&self,
538+
function_name: &str,
539+
channel_name: &str,
540+
trigger_name: &str,
541+
) -> Result<()> {
542+
let json_object_expr = |var: &str| {
543+
let mut fields = (self.table_schema.primary_key_columns.iter())
544+
.chain(self.table_schema.ordinal_field_schema.iter())
545+
.map(|col| {
546+
let field_name = &col.schema.name;
547+
if matches!(
548+
col.schema.value_type.typ,
549+
ValueType::Basic(BasicValueType::Bytes)
550+
) {
551+
format!("'{field_name}', encode({var}.\"{field_name}\", 'base64')")
552+
} else {
553+
format!("'{field_name}', {var}.\"{field_name}\"")
554+
}
555+
});
556+
format!("jsonb_build_object({})", fields.join(", "))
557+
};
558+
559+
let statements = [
560+
formatdoc! {r#"
561+
CREATE OR REPLACE FUNCTION {function_name}() RETURNS TRIGGER AS $$
562+
BEGIN
563+
PERFORM pg_notify('{channel_name}', jsonb_build_object(
564+
'op', TG_OP,
565+
'fields',
566+
CASE WHEN TG_OP IN ('INSERT', 'UPDATE') THEN {json_object_expr_new}
567+
WHEN TG_OP = 'DELETE' THEN {json_object_expr_old}
568+
ELSE NULL END
569+
)::text);
570+
RETURN NULL;
571+
END;
572+
$$ LANGUAGE plpgsql;
573+
"#,
574+
function_name = function_name,
575+
channel_name = channel_name,
576+
json_object_expr_new = json_object_expr("NEW"),
577+
json_object_expr_old = json_object_expr("OLD"),
578+
},
579+
format!(
580+
"DROP TRIGGER IF EXISTS {trigger_name} ON \"{table_name}\";",
581+
trigger_name = trigger_name,
582+
table_name = self.table_name,
583+
),
584+
formatdoc! {r#"
585+
CREATE TRIGGER {trigger_name}
586+
AFTER INSERT OR UPDATE OR DELETE ON "{table_name}"
587+
FOR EACH ROW EXECUTE FUNCTION {function_name}();
588+
"#,
589+
trigger_name = trigger_name,
590+
table_name = self.table_name,
591+
function_name = function_name,
592+
},
593+
];
594+
595+
let mut tx = self.db_pool.begin().await?;
596+
for stmt in statements {
597+
sqlx::query(&stmt).execute(&mut *tx).await?;
598+
}
599+
tx.commit().await?;
600+
Ok(())
601+
}
602+
603+
fn parse_notification_payload(&self, notification: &PgNotification) -> Result<SourceChange> {
604+
let mut payload: serde_json::Value = serde_json::from_str(notification.payload())?;
605+
let payload = payload
606+
.as_object_mut()
607+
.ok_or_else(|| anyhow::anyhow!("'fields' field is not an object"))?;
608+
609+
let Some(serde_json::Value::String(op)) = payload.get_mut("op") else {
610+
return Err(anyhow::anyhow!(
611+
"Missing or invalid 'op' field in notification"
612+
));
613+
};
614+
let op = std::mem::take(op);
615+
616+
let mut fields = std::mem::take(
617+
payload
618+
.get_mut("fields")
619+
.ok_or_else(|| anyhow::anyhow!("Missing 'fields' field in notification"))?
620+
.as_object_mut()
621+
.ok_or_else(|| anyhow::anyhow!("'fields' field is not an object"))?,
622+
);
623+
624+
// Extract primary key values to construct the key
625+
let mut key_parts = Vec::with_capacity(self.table_schema.primary_key_columns.len());
626+
for pk_col in &self.table_schema.primary_key_columns {
627+
let field_value = fields.get_mut(&pk_col.schema.name).ok_or_else(|| {
628+
anyhow::anyhow!("Missing primary key field: {}", pk_col.schema.name)
629+
})?;
630+
631+
let key_part = Self::decode_key_ordinal_value_in_json(
632+
std::mem::take(field_value),
633+
&pk_col.schema.value_type.typ,
634+
)?
635+
.into_key()?;
636+
key_parts.push(key_part);
637+
}
638+
639+
let key = KeyValue(key_parts.into_boxed_slice());
640+
641+
// Extract ordinal if available
642+
let ordinal = if let Some(ord_schema) = &self.table_schema.ordinal_field_schema {
643+
if let Some(ord_value) = fields.get_mut(&ord_schema.schema.name) {
644+
let value = Self::decode_key_ordinal_value_in_json(
645+
std::mem::take(ord_value),
646+
&ord_schema.schema.value_type.typ,
647+
)?;
648+
Some(value_to_ordinal(&value))
649+
} else {
650+
Some(Ordinal::unavailable())
651+
}
652+
} else {
653+
None
654+
};
655+
656+
let data = match op.as_str() {
657+
"DELETE" => PartialSourceRowData {
658+
value: Some(SourceValue::NonExistence),
659+
ordinal,
660+
content_version_fp: None,
661+
},
662+
"INSERT" | "UPDATE" => {
663+
// For INSERT/UPDATE, we signal that the row exists but don't include the full value
664+
// The engine will call get_value() to retrieve the actual data
665+
PartialSourceRowData {
666+
value: None, // Let the engine fetch the value
667+
ordinal,
668+
content_version_fp: None,
669+
}
670+
}
671+
_ => return Err(anyhow::anyhow!("Unknown operation: {}", op)),
672+
};
673+
674+
Ok(SourceChange {
675+
key,
676+
key_aux_info: serde_json::Value::Null,
677+
data,
678+
})
679+
}
680+
681+
fn decode_key_ordinal_value_in_json(
682+
json_value: serde_json::Value,
683+
value_type: &ValueType,
684+
) -> Result<Value> {
685+
let result = match (value_type, json_value) {
686+
(_, serde_json::Value::Null) => Value::Null,
687+
(ValueType::Basic(BasicValueType::Bool), serde_json::Value::Bool(b)) => {
688+
BasicValue::Bool(b).into()
689+
}
690+
(ValueType::Basic(BasicValueType::Bytes), serde_json::Value::String(s)) => {
691+
let bytes = BASE64_STANDARD.decode(&s)?;
692+
BasicValue::Bytes(bytes::Bytes::from(bytes)).into()
693+
}
694+
(ValueType::Basic(BasicValueType::Str), serde_json::Value::String(s)) => {
695+
BasicValue::Str(s.into()).into()
696+
}
697+
(ValueType::Basic(BasicValueType::Int64), serde_json::Value::Number(n)) => {
698+
if let Some(i) = n.as_i64() {
699+
BasicValue::Int64(i).into()
700+
} else {
701+
bail!("Invalid integer value: {}", n)
702+
}
703+
}
704+
(ValueType::Basic(BasicValueType::Uuid), serde_json::Value::String(s)) => {
705+
let uuid = s.parse::<uuid::Uuid>()?;
706+
BasicValue::Uuid(uuid).into()
707+
}
708+
(ValueType::Basic(BasicValueType::Date), serde_json::Value::String(s)) => {
709+
let dt = s.parse::<chrono::NaiveDate>()?;
710+
BasicValue::Date(dt).into()
711+
}
712+
(ValueType::Basic(BasicValueType::LocalDateTime), serde_json::Value::String(s)) => {
713+
let dt = s.parse::<chrono::NaiveDateTime>()?;
714+
BasicValue::LocalDateTime(dt).into()
715+
}
716+
(ValueType::Basic(BasicValueType::OffsetDateTime), serde_json::Value::String(s)) => {
717+
let dt = s.parse::<chrono::DateTime<chrono::FixedOffset>>()?;
718+
BasicValue::OffsetDateTime(dt).into()
719+
}
720+
(_, json_value) => {
721+
bail!(
722+
"Got unsupported JSON value for type {value_type}: {}",
723+
serde_json::to_string(&json_value)?
724+
);
725+
}
726+
};
727+
Ok(result)
728+
}
729+
}
730+
495731
pub struct Factory;
496732

497733
#[async_trait]
@@ -552,10 +788,11 @@ impl SourceFactoryBase for Factory {
552788
)
553789
.await?;
554790

555-
let executor = Executor {
791+
let executor = PostgresSourceExecutor {
556792
db_pool,
557793
table_name: spec.table_name.clone(),
558794
table_schema,
795+
notification: spec.notification.clone(),
559796
};
560797

561798
Ok(Box::new(executor))

0 commit comments

Comments
 (0)