Merged
24 changes: 24 additions & 0 deletions Cargo.lock

(Generated lockfile; diff not rendered.)

1 change: 1 addition & 0 deletions Cargo.toml
@@ -111,3 +111,4 @@ owo-colors = "4.2.0"
json5 = "0.4.1"
aws-config = "1.6.2"
aws-sdk-s3 = "1.85.0"
aws-sdk-sqs = "1.67.0"
7 changes: 6 additions & 1 deletion examples/amazon_s3_text_embedding/.env.example
@@ -3,4 +3,9 @@ COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex

# Amazon S3 Configuration
AMAZON_S3_BUCKET_NAME=your-bucket-name
AMAZON_S3_PREFIX=optional/prefix/path

# Optional: only index objects under this key prefix
# AMAZON_S3_PREFIX=

# Optional: SQS queue that receives S3 event notifications for the bucket,
# enabling near-real-time change capture
# AMAZON_S3_SQS_QUEUE_URL=
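For reference, a filled-in configuration might look like the following; the bucket, prefix, and queue names here are hypothetical, and the queue URL simply follows the standard https://sqs.&lt;region&gt;.amazonaws.com/&lt;account-id&gt;/&lt;queue-name&gt; form:

AMAZON_S3_BUCKET_NAME=my-docs-bucket
AMAZON_S3_PREFIX=docs/
AMAZON_S3_SQS_QUEUE_URL=https://sqs.us-east-1.amazonaws.com/123456789012/my-docs-events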
4 changes: 3 additions & 1 deletion examples/amazon_s3_text_embedding/main.py
@@ -12,13 +12,15 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop
"""
bucket_name = os.environ["AMAZON_S3_BUCKET_NAME"]
prefix = os.environ.get("AMAZON_S3_PREFIX", None)
sqs_queue_url = os.environ.get("AMAZON_S3_SQS_QUEUE_URL", None)

data_scope["documents"] = flow_builder.add_source(
cocoindex.sources.AmazonS3(
bucket_name=bucket_name,
prefix=prefix,
included_patterns=["*.md", "*.txt", "*.docx"],
binary=False),
binary=False,
sqs_queue_url=sqs_queue_url),
refresh_interval=datetime.timedelta(minutes=1))

doc_embeddings = data_scope.add_collector()
1 change: 1 addition & 0 deletions python/cocoindex/sources.py
@@ -40,3 +40,4 @@ class AmazonS3(op.SourceSpec):
binary: bool = False
included_patterns: list[str] | None = None
excluded_patterns: list[str] | None = None
sqs_queue_url: str | None = None
141 changes: 121 additions & 20 deletions src/ops/sources/amazon_s3.rs
@@ -1,7 +1,6 @@
use crate::fields_value;
use async_stream::{stream, try_stream};
use aws_config::meta::region::RegionProviderChain;
use aws_config::Region;
use aws_config::BehaviorVersion;
use aws_sdk_s3::Client;
use globset::{Glob, GlobSet, GlobSetBuilder};
use log::warn;
@@ -17,15 +16,21 @@ pub struct Spec {
binary: bool,
included_patterns: Option<Vec<String>>,
excluded_patterns: Option<Vec<String>>,
sqs_queue_url: Option<String>,
}

struct SqsContext {
client: aws_sdk_sqs::Client,
queue_url: String,
}
struct Executor {
client: Client,
bucket_name: String,
prefix: Option<String>,
binary: bool,
included_glob_set: Option<GlobSet>,
excluded_glob_set: Option<GlobSet>,
sqs_context: Option<Arc<SqsContext>>,
}

impl Executor {
@@ -53,18 +58,13 @@ impl SourceExecutor for Executor {
&'a self,
_options: &'a SourceExecutorListOptions,
) -> BoxStream<'a, Result<Vec<SourceRowMetadata>>> {
let client = &self.client;
let bucket = &self.bucket_name;
let prefix = &self.prefix;
let included_glob_set = &self.included_glob_set;
let excluded_glob_set = &self.excluded_glob_set;
try_stream! {
let mut continuation_token = None;
loop {
let mut req = client
let mut req = self.client
.list_objects_v2()
.bucket(bucket);
if let Some(ref p) = prefix {
.bucket(&self.bucket_name);
if let Some(ref p) = self.prefix {
req = req.prefix(p);
}
if let Some(ref token) = continuation_token {
@@ -77,11 +77,11 @@
if let Some(key) = obj.key() {
// Only include files (not folders)
if key.ends_with('/') { continue; }
let include = included_glob_set
let include = self.included_glob_set
.as_ref()
.map(|gs| gs.is_match(key))
.unwrap_or(true);
let exclude = excluded_glob_set
let exclude = self.excluded_glob_set
.as_ref()
.map(|gs| gs.is_match(key))
.unwrap_or(false);
@@ -152,6 +152,107 @@ impl SourceExecutor for Executor {
};
Ok(Some(SourceValue { value, ordinal }))
}

async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
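// If no SQS queue is configured, there is no change stream; the source then
// relies on periodic re-listing (e.g. the refresh_interval set by the caller).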
let sqs_context = if let Some(sqs_context) = &self.sqs_context {
sqs_context
} else {
return Ok(None);
};
let stream = stream! {
loop {
let changes = match self.poll_sqs(&sqs_context).await {
Ok(changes) => changes,
Err(e) => {
warn!("Failed to poll SQS: {}", e);
// Back off briefly so a persistent SQS failure doesn't become a hot loop.
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
};
for change in changes {
yield change;
}
}
};
Ok(Some(stream.boxed()))
}
}

#[derive(Debug, Deserialize)]
pub struct S3EventNotification {
#[serde(rename = "Records")]
pub records: Vec<S3EventRecord>,
}

#[derive(Debug, Deserialize)]
pub struct S3EventRecord {
#[serde(rename = "eventName")]
pub event_name: String,
// pub eventTime: String,
pub s3: S3Entity,
}

#[derive(Debug, Deserialize)]
pub struct S3Entity {
pub bucket: S3Bucket,
pub object: S3Object,
}

#[derive(Debug, Deserialize)]
pub struct S3Bucket {
pub name: String,
}

#[derive(Debug, Deserialize)]
pub struct S3Object {
pub key: String,
}
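These structs mirror the standard S3 event notification payload (Records[].eventName, s3.bucket.name, s3.object.key). As a quick illustration — a sketch, not part of the diff, with hypothetical bucket and key values — they deserialize a sample event like this:

#[cfg(test)]
mod s3_event_shape_tests {
    use super::*;

    #[test]
    fn parses_s3_event_notification() {
        // A trimmed-down S3 event notification body, as delivered via SQS.
        let body = r#"{
            "Records": [{
                "eventName": "ObjectCreated:Put",
                "s3": {
                    "bucket": { "name": "my-docs-bucket" },
                    "object": { "key": "docs/readme.md" }
                }
            }]
        }"#;
        let n: S3EventNotification = serde_json::from_str(body).unwrap();
        assert_eq!(n.records[0].event_name, "ObjectCreated:Put");
        assert_eq!(n.records[0].s3.bucket.name, "my-docs-bucket");
        assert_eq!(n.records[0].s3.object.key, "docs/readme.md");
    }
}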

impl Executor {
async fn poll_sqs(&self, sqs_context: &Arc<SqsContext>) -> Result<Vec<SourceChange>> {
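// SQS long polling: wait_time_seconds(20) keeps the receive call open for up
// to 20s, so an idle queue costs a few requests per minute rather than a busy loop.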
let resp = sqs_context
.client
.receive_message()
.queue_url(&sqs_context.queue_url)
.max_number_of_messages(10)
.wait_time_seconds(20)
.send()
.await?;
let messages = if let Some(messages) = resp.messages {
messages
} else {
return Ok(Vec::new());
};
let mut changes = vec![];
for message in messages.into_iter().filter_map(|m| m.body) {
let notification: S3EventNotification = serde_json::from_str(&message)?;
for record in notification.records {
if record.s3.bucket.name != self.bucket_name {
continue;
}
if !self
.prefix
.as_ref()
.map_or(true, |prefix| record.s3.object.key.starts_with(prefix))
{
continue;
}
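// Note: keys in S3 event payloads are URL-encoded; a key containing spaces
// or special characters would need decoding before being used as-is here.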
if record.event_name.starts_with("ObjectCreated:") {
changes.push(SourceChange {
key: KeyValue::Str(record.s3.object.key.into()),
ordinal: None,
value: SourceValueChange::Upsert(None),
});
} else if record.event_name.starts_with("ObjectDeleted:") {
changes.push(SourceChange {
key: KeyValue::Str(record.s3.object.key.into()),
ordinal: None,
value: SourceValueChange::Delete,
});
}
}
}
Ok(changes)
}
}

pub struct Factory;
@@ -198,20 +299,20 @@ impl SourceFactoryBase for Factory {
spec: Spec,
_context: Arc<FlowInstanceContext>,
) -> Result<Box<dyn SourceExecutor>> {
let region_provider =
RegionProviderChain::default_provider().or_else(Region::new("us-east-1"));
let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(region_provider)
.load()
.await;
let client = Client::new(&config);
let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
Ok(Box::new(Executor {
client,
client: Client::new(&config),
bucket_name: spec.bucket_name,
prefix: spec.prefix,
binary: spec.binary,
included_glob_set: spec.included_patterns.map(build_glob_set).transpose()?,
excluded_glob_set: spec.excluded_patterns.map(build_glob_set).transpose()?,
sqs_context: spec.sqs_queue_url.map(|url| {
Arc::new(SqsContext {
client: aws_sdk_sqs::Client::new(&config),
queue_url: url,
})
}),
}))
}
}
Expand Down