Skip to content

Commit b5f8f2a

Browse files
committed
feat(s3-push): preliminary change event support
1 parent cd94481 commit b5f8f2a

File tree

6 files changed

+156
-22
lines changed

6 files changed

+156
-22
lines changed

Cargo.lock

Lines changed: 24 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ owo-colors = "4.2.0"
111111
json5 = "0.4.1"
112112
aws-config = "1.6.2"
113113
aws-sdk-s3 = "1.85.0"
114+
aws-sdk-sqs = "1.67.0"

examples/amazon_s3_text_embedding/.env.example

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,9 @@ COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex
33

44
# Amazon S3 Configuration
55
AMAZON_S3_BUCKET_NAME=your-bucket-name
6-
AMAZON_S3_PREFIX=optional/prefix/path
6+
7+
# Optional
8+
# AMAZON_S3_PREFIX=
9+
10+
# Optional
11+
# AMAZON_S3_SQS_QUEUE_URL=

examples/amazon_s3_text_embedding/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop
1212
"""
1313
bucket_name = os.environ["AMAZON_S3_BUCKET_NAME"]
1414
prefix = os.environ.get("AMAZON_S3_PREFIX", None)
15+
sqs_queue_url = os.environ.get("AMAZON_S3_SQS_QUEUE_URL", None)
1516

1617
data_scope["documents"] = flow_builder.add_source(
1718
cocoindex.sources.AmazonS3(
1819
bucket_name=bucket_name,
1920
prefix=prefix,
2021
included_patterns=["*.md", "*.txt", "*.docx"],
21-
binary=False),
22+
binary=False,
23+
sqs_queue_url=sqs_queue_url),
2224
refresh_interval=datetime.timedelta(minutes=1))
2325

2426
doc_embeddings = data_scope.add_collector()

python/cocoindex/sources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ class AmazonS3(op.SourceSpec):
4040
binary: bool = False
4141
included_patterns: list[str] | None = None
4242
excluded_patterns: list[str] | None = None
43+
sqs_queue_url: str | None = None

src/ops/sources/amazon_s3.rs

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::fields_value;
22
use async_stream::try_stream;
3-
use aws_config::meta::region::RegionProviderChain;
4-
use aws_config::Region;
3+
use aws_config::BehaviorVersion;
54
use aws_sdk_s3::Client;
65
use globset::{Glob, GlobSet, GlobSetBuilder};
76
use log::warn;
@@ -17,15 +16,21 @@ pub struct Spec {
1716
binary: bool,
1817
included_patterns: Option<Vec<String>>,
1918
excluded_patterns: Option<Vec<String>>,
19+
sqs_queue_url: Option<String>,
2020
}
2121

22+
/// SQS client bundled with the URL of the queue it polls.
struct SqsContext {
23+
// Client used to issue requests against the queue.
client: aws_sdk_sqs::Client,
24+
// URL of the queue, passed on each `receive_message` request.
queue_url: String,
25+
}
2226
/// Source executor backed by an S3 bucket, optionally fed by an SQS queue
/// for change notifications.
struct Executor {
2327
// S3 client used for `list_objects_v2` requests.
client: Client,
2428
// Bucket that is listed; SQS records for other buckets are ignored.
bucket_name: String,
2529
// Optional key prefix applied to listings and to incoming SQS records.
prefix: Option<String>,
2630
// NOTE(review): presumably selects raw-bytes vs. text content; the code
// that reads this flag is outside the visible hunk — confirm.
binary: bool,
2731
// When set, only keys matching this glob set are listed.
included_glob_set: Option<GlobSet>,
2832
// When set, keys matching this glob set are dropped from listings.
excluded_glob_set: Option<GlobSet>,
33+
// SQS polling context; `None` means no change stream is offered.
sqs_context: Option<Arc<SqsContext>>,
2934
}
3035

3136
impl Executor {
@@ -53,18 +58,13 @@ impl SourceExecutor for Executor {
5358
&'a self,
5459
_options: &'a SourceExecutorListOptions,
5560
) -> BoxStream<'a, Result<Vec<SourceRowMetadata>>> {
56-
let client = &self.client;
57-
let bucket = &self.bucket_name;
58-
let prefix = &self.prefix;
59-
let included_glob_set = &self.included_glob_set;
60-
let excluded_glob_set = &self.excluded_glob_set;
6161
try_stream! {
6262
let mut continuation_token = None;
6363
loop {
64-
let mut req = client
64+
let mut req = self.client
6565
.list_objects_v2()
66-
.bucket(bucket);
67-
if let Some(ref p) = prefix {
66+
.bucket(&self.bucket_name);
67+
if let Some(ref p) = self.prefix {
6868
req = req.prefix(p);
6969
}
7070
if let Some(ref token) = continuation_token {
@@ -77,11 +77,11 @@ impl SourceExecutor for Executor {
7777
if let Some(key) = obj.key() {
7878
// Only include files (not folders)
7979
if key.ends_with('/') { continue; }
80-
let include = included_glob_set
80+
let include = self.included_glob_set
8181
.as_ref()
8282
.map(|gs| gs.is_match(key))
8383
.unwrap_or(true);
84-
let exclude = excluded_glob_set
84+
let exclude = self.excluded_glob_set
8585
.as_ref()
8686
.map(|gs| gs.is_match(key))
8787
.unwrap_or(false);
@@ -152,6 +152,107 @@ impl SourceExecutor for Executor {
152152
};
153153
Ok(Some(SourceValue { value, ordinal }))
154154
}
155+
156+
async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
157+
let sqs_context = if let Some(sqs_context) = &self.sqs_context {
158+
sqs_context
159+
} else {
160+
return Ok(None);
161+
};
162+
let stream = stream! {
163+
loop {
164+
let changes = match self.poll_sqs(&sqs_context).await {
165+
Ok(changes) => changes,
166+
Err(e) => {
167+
warn!("Failed to poll SQS: {}", e);
168+
continue;
169+
}
170+
};
171+
for change in changes {
172+
yield change;
173+
}
174+
}
175+
};
176+
Ok(Some(stream.boxed()))
177+
}
178+
}
179+
180+
#[derive(Debug, Deserialize)]
181+
pub struct S3EventNotification {
182+
#[serde(rename = "Records")]
183+
pub records: Vec<S3EventRecord>,
184+
}
185+
186+
/// One entry of the `Records` array in an S3 event notification.
#[derive(Debug, Deserialize)]
187+
pub struct S3EventRecord {
188+
// Event type string from the `eventName` JSON field; `poll_sqs`
// dispatches on its prefix.
#[serde(rename = "eventName")]
189+
pub event_name: String,
190+
// `eventTime` is not deserialized at the moment:
// pub eventTime: String,
191+
// Bucket and object the event refers to.
pub s3: S3Entity,
192+
}
193+
194+
/// The `s3` section of an event record: which bucket and which object.
#[derive(Debug, Deserialize)]
195+
pub struct S3Entity {
196+
// Bucket the event originated from.
pub bucket: S3Bucket,
197+
// Object the event is about.
pub object: S3Object,
198+
}
199+
200+
/// Bucket descriptor inside an S3 event record.
#[derive(Debug, Deserialize)]
201+
pub struct S3Bucket {
202+
// Bucket name; `poll_sqs` compares it with the executor's `bucket_name`.
pub name: String,
203+
}
204+
205+
/// Object descriptor inside an S3 event record.
#[derive(Debug, Deserialize)]
205+
pub struct S3Object {
206+
// Object key exactly as it appears in the notification payload.
// NOTE(review): AWS documents this value as URL-encoded (e.g. spaces
// arrive as `+`) — consumers need to account for that.
pub key: String,
207+
}
209+
210+
impl Executor {
211+
async fn poll_sqs(&self, sqs_context: &Arc<SqsContext>) -> Result<Vec<SourceChange>> {
212+
let resp = sqs_context
213+
.client
214+
.receive_message()
215+
.queue_url(&sqs_context.queue_url)
216+
.max_number_of_messages(10)
217+
.wait_time_seconds(20)
218+
.send()
219+
.await?;
220+
let messages = if let Some(messages) = resp.messages {
221+
messages
222+
} else {
223+
return Ok(Vec::new());
224+
};
225+
let mut changes = vec![];
226+
for message in messages.into_iter().filter_map(|m| m.body) {
227+
let notification: S3EventNotification = serde_json::from_str(&message)?;
228+
for record in notification.records {
229+
if record.s3.bucket.name != self.bucket_name {
230+
continue;
231+
}
232+
if !self
233+
.prefix
234+
.as_ref()
235+
.map_or(true, |prefix| record.s3.object.key.starts_with(prefix))
236+
{
237+
continue;
238+
}
239+
if record.event_name.starts_with("ObjectCreated:") {
240+
changes.push(SourceChange {
241+
key: KeyValue::Str(record.s3.object.key.into()),
242+
ordinal: None,
243+
value: SourceValueChange::Upsert(None),
244+
});
245+
} else if record.event_name.starts_with("ObjectDeleted:") {
246+
changes.push(SourceChange {
247+
key: KeyValue::Str(record.s3.object.key.into()),
248+
ordinal: None,
249+
value: SourceValueChange::Delete,
250+
});
251+
}
252+
}
253+
}
254+
Ok(changes)
255+
}
155256
}
156257

157258
pub struct Factory;
@@ -198,20 +299,20 @@ impl SourceFactoryBase for Factory {
198299
spec: Spec,
199300
_context: Arc<FlowInstanceContext>,
200301
) -> Result<Box<dyn SourceExecutor>> {
201-
let region_provider =
202-
RegionProviderChain::default_provider().or_else(Region::new("us-east-1"));
203-
let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
204-
.region(region_provider)
205-
.load()
206-
.await;
207-
let client = Client::new(&config);
302+
let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
208303
Ok(Box::new(Executor {
209-
client,
304+
client: Client::new(&config),
210305
bucket_name: spec.bucket_name,
211306
prefix: spec.prefix,
212307
binary: spec.binary,
213308
included_glob_set: spec.included_patterns.map(build_glob_set).transpose()?,
214309
excluded_glob_set: spec.excluded_patterns.map(build_glob_set).transpose()?,
310+
sqs_context: spec.sqs_queue_url.map(|url| {
311+
Arc::new(SqsContext {
312+
client: aws_sdk_sqs::Client::new(&config),
313+
queue_url: url,
314+
})
315+
}),
215316
}))
216317
}
217318
}

0 commit comments

Comments
 (0)