From b5f8f2aebbb0c97fa0e41cb5c603836dbdb99f62 Mon Sep 17 00:00:00 2001 From: LJ Date: Mon, 12 May 2025 22:47:39 -0700 Subject: [PATCH] feat(s3-push): preliminary change event support --- Cargo.lock | 24 +++ Cargo.toml | 1 + .../amazon_s3_text_embedding/.env.example | 7 +- examples/amazon_s3_text_embedding/main.py | 4 +- python/cocoindex/sources.py | 1 + src/ops/sources/amazon_s3.rs | 141 +++++++++++++++--- 6 files changed, 156 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a410bfc4..848cbd6dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -337,6 +337,29 @@ dependencies = [ "url", ] +[[package]] +name = "aws-sdk-sqs" +version = "1.67.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6f15bedfb1c4385fccc474f0fe46dffb0335d0b3d6b4413df06fb30d90caba8" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-sso" version = "1.67.0" @@ -978,6 +1001,7 @@ dependencies = [ "async-trait", "aws-config", "aws-sdk-s3", + "aws-sdk-sqs", "axum", "axum-extra", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index 82c5e8e30..3fc30c83a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,3 +111,4 @@ owo-colors = "4.2.0" json5 = "0.4.1" aws-config = "1.6.2" aws-sdk-s3 = "1.85.0" +aws-sdk-sqs = "1.67.0" diff --git a/examples/amazon_s3_text_embedding/.env.example b/examples/amazon_s3_text_embedding/.env.example index e199b294e..843822ecd 100644 --- a/examples/amazon_s3_text_embedding/.env.example +++ b/examples/amazon_s3_text_embedding/.env.example @@ -3,4 +3,9 @@ COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex # Amazon S3 Configuration AMAZON_S3_BUCKET_NAME=your-bucket-name -AMAZON_S3_PREFIX=optional/prefix/path \ No newline at end of file + +# Optional +# AMAZON_S3_PREFIX= + +# Optional +# AMAZON_S3_SQS_QUEUE_URL= \ No newline at end of file diff --git a/examples/amazon_s3_text_embedding/main.py b/examples/amazon_s3_text_embedding/main.py index 0a94e02aa..c82aa136d 100644 --- a/examples/amazon_s3_text_embedding/main.py +++ b/examples/amazon_s3_text_embedding/main.py @@ -12,13 +12,15 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop """ bucket_name = os.environ["AMAZON_S3_BUCKET_NAME"] prefix = os.environ.get("AMAZON_S3_PREFIX", None) + sqs_queue_url = os.environ.get("AMAZON_S3_SQS_QUEUE_URL", None) data_scope["documents"] = flow_builder.add_source( cocoindex.sources.AmazonS3( bucket_name=bucket_name, prefix=prefix, included_patterns=["*.md", "*.txt", "*.docx"], - binary=False), + binary=False, + sqs_queue_url=sqs_queue_url), refresh_interval=datetime.timedelta(minutes=1)) doc_embeddings = data_scope.add_collector() diff --git a/python/cocoindex/sources.py b/python/cocoindex/sources.py index c12eac806..8356ba2e8 100644 --- a/python/cocoindex/sources.py +++ b/python/cocoindex/sources.py @@ -40,3 +40,4 @@ class AmazonS3(op.SourceSpec): binary: bool = False included_patterns: list[str] | None = None excluded_patterns: list[str] | None = None + sqs_queue_url: str | None = None \ No newline at end of file diff --git a/src/ops/sources/amazon_s3.rs b/src/ops/sources/amazon_s3.rs index cd01edc2b..5056ea1bf 100644 --- a/src/ops/sources/amazon_s3.rs +++ b/src/ops/sources/amazon_s3.rs @@ -1,7 +1,6 @@ use crate::fields_value; use async_stream::try_stream; -use aws_config::meta::region::RegionProviderChain; -use aws_config::Region; +use aws_config::BehaviorVersion; use aws_sdk_s3::Client; use globset::{Glob, GlobSet, GlobSetBuilder}; use log::warn; @@ -17,8 +16,13 @@ pub struct Spec { binary: bool, included_patterns: Option>, excluded_patterns: Option>, + sqs_queue_url: Option, } +struct SqsContext { + client: aws_sdk_sqs::Client, + queue_url: String, +} struct Executor { client: Client, bucket_name: String, @@ -26,6 +30,7 @@ struct Executor { binary: bool, included_glob_set: Option, excluded_glob_set: Option, + sqs_context: Option>, } impl Executor { @@ -53,18 +58,13 @@ impl SourceExecutor for Executor { &'a self, _options: &'a SourceExecutorListOptions, ) -> BoxStream<'a, Result>> { - let client = &self.client; - let bucket = &self.bucket_name; - let prefix = &self.prefix; - let included_glob_set = &self.included_glob_set; - let excluded_glob_set = &self.excluded_glob_set; try_stream! { let mut continuation_token = None; loop { - let mut req = client + let mut req = self.client .list_objects_v2() - .bucket(bucket); - if let Some(ref p) = prefix { + .bucket(&self.bucket_name); + if let Some(ref p) = self.prefix { req = req.prefix(p); } if let Some(ref token) = continuation_token { @@ -77,11 +77,11 @@ impl SourceExecutor for Executor { if let Some(key) = obj.key() { // Only include files (not folders) if key.ends_with('/') { continue; } - let include = included_glob_set + let include = self.included_glob_set .as_ref() .map(|gs| gs.is_match(key)) .unwrap_or(true); - let exclude = excluded_glob_set + let exclude = self.excluded_glob_set .as_ref() .map(|gs| gs.is_match(key)) .unwrap_or(false); @@ -152,6 +152,107 @@ impl SourceExecutor for Executor { }; Ok(Some(SourceValue { value, ordinal })) } + + async fn change_stream(&self) -> Result>> { + let sqs_context = if let Some(sqs_context) = &self.sqs_context { + sqs_context + } else { + return Ok(None); + }; + let stream = stream! { + loop { + let changes = match self.poll_sqs(&sqs_context).await { + Ok(changes) => changes, + Err(e) => { + warn!("Failed to poll SQS: {}", e); + continue; + } + }; + for change in changes { + yield change; + } + } + }; + Ok(Some(stream.boxed())) + } +} + +#[derive(Debug, Deserialize)] +pub struct S3EventNotification { + #[serde(rename = "Records")] + pub records: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct S3EventRecord { + #[serde(rename = "eventName")] + pub event_name: String, + // pub eventTime: String, + pub s3: S3Entity, +} + +#[derive(Debug, Deserialize)] +pub struct S3Entity { + pub bucket: S3Bucket, + pub object: S3Object, +} + +#[derive(Debug, Deserialize)] +pub struct S3Bucket { + pub name: String, +} + +#[derive(Debug, Deserialize)] +pub struct S3Object { + pub key: String, +} + +impl Executor { + async fn poll_sqs(&self, sqs_context: &Arc) -> Result> { + let resp = sqs_context + .client + .receive_message() + .queue_url(&sqs_context.queue_url) + .max_number_of_messages(10) + .wait_time_seconds(20) + .send() + .await?; + let messages = if let Some(messages) = resp.messages { + messages + } else { + return Ok(Vec::new()); + }; + let mut changes = vec![]; + for message in messages.into_iter().filter_map(|m| m.body) { + let notification: S3EventNotification = serde_json::from_str(&message)?; + for record in notification.records { + if record.s3.bucket.name != self.bucket_name { + continue; + } + if !self + .prefix + .as_ref() + .map_or(true, |prefix| record.s3.object.key.starts_with(prefix)) + { + continue; + } + if record.event_name.starts_with("ObjectCreated:") { + changes.push(SourceChange { + key: KeyValue::Str(record.s3.object.key.into()), + ordinal: None, + value: SourceValueChange::Upsert(None), + }); + } else if record.event_name.starts_with("ObjectDeleted:") { + changes.push(SourceChange { + key: KeyValue::Str(record.s3.object.key.into()), + ordinal: None, + value: SourceValueChange::Delete, + }); + } + } + } + Ok(changes) + } } pub struct Factory; @@ -198,20 +299,20 @@ impl SourceFactoryBase for Factory { spec: Spec, _context: Arc, ) -> Result> { - let region_provider = - RegionProviderChain::default_provider().or_else(Region::new("us-east-1")); - let config = aws_config::defaults(aws_config::BehaviorVersion::latest()) - .region(region_provider) - .load() - .await; - let client = Client::new(&config); + let config = aws_config::load_defaults(BehaviorVersion::latest()).await; Ok(Box::new(Executor { - client, + client: Client::new(&config), bucket_name: spec.bucket_name, prefix: spec.prefix, binary: spec.binary, included_glob_set: spec.included_patterns.map(build_glob_set).transpose()?, excluded_glob_set: spec.excluded_patterns.map(build_glob_set).transpose()?, + sqs_context: spec.sqs_queue_url.map(|url| { + Arc::new(SqsContext { + client: aws_sdk_sqs::Client::new(&config), + queue_url: url, + }) + }), })) } }