diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs
index dd325345b4d..56162c77be5 100644
--- a/rust/worker/src/execution/operators/mod.rs
+++ b/rust/worker/src/execution/operators/mod.rs
@@ -27,3 +27,4 @@ pub mod select;
 pub mod source_record_segment;
 pub mod sparse_index_knn;
 pub mod sparse_log_knn;
+pub mod transform_log;
diff --git a/rust/worker/src/execution/operators/transform_log.rs b/rust/worker/src/execution/operators/transform_log.rs
new file mode 100644
index 00000000000..e36361ca0a0
--- /dev/null
+++ b/rust/worker/src/execution/operators/transform_log.rs
@@ -0,0 +1,60 @@
+use async_trait::async_trait;
+use chroma_error::{ChromaError, ErrorCodes};
+use chroma_system::Operator;
+use chroma_types::{Chunk, LogRecord};
+use thiserror::Error;
+
+#[derive(Debug)]
+pub struct TransformOperator {}
+
+#[derive(Debug)]
+pub struct TransformInput {
+    pub(crate) records: Chunk<LogRecord>,
+}
+
+impl TransformInput {
+    pub fn new(records: Chunk<LogRecord>) -> Self {
+        TransformInput { records }
+    }
+}
+
+#[derive(Debug)]
+pub struct TransformOutput {
+    pub(crate) records: Chunk<LogRecord>,
+}
+
+#[derive(Debug, Error)]
+#[error("Failed to transform records.")]
+pub struct TransformError;
+
+impl ChromaError for TransformError {
+    fn code(&self) -> ErrorCodes {
+        ErrorCodes::Internal
+    }
+}
+
+impl TransformOperator {
+    pub fn new() -> Box<Self> {
+        Box::new(TransformOperator {})
+    }
+
+    pub fn transform(&self, records: &Chunk<LogRecord>) -> Chunk<LogRecord> {
+        records.clone()
+    }
+}
+
+#[async_trait]
+impl Operator<TransformInput, TransformOutput> for TransformOperator {
+    type Error = TransformError;
+
+    fn get_name(&self) -> &'static str {
+        "TransformOperator"
+    }
+
+    async fn run(&self, input: &TransformInput) -> Result<TransformOutput, TransformError> {
+        let transformed_records = self.transform(&input.records);
+        Ok(TransformOutput {
+            records: transformed_records,
+        })
+    }
+}
diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs
index bf6f78d19ed..dc4496f4f34 100644
--- a/rust/worker/src/execution/orchestration/compact.rs
+++ b/rust/worker/src/execution/orchestration/compact.rs
@@ -68,6 +68,7 @@ use crate::execution::operators::{
         SourceRecordSegmentError, SourceRecordSegmentInput, SourceRecordSegmentOperator,
         SourceRecordSegmentOutput,
     },
+    transform_log::{TransformError, TransformInput, TransformOperator, TransformOutput},
 };
 
 /** The state of the orchestrator.
@@ -111,6 +112,7 @@ impl Default for CompactOrchestratorMetrics {
 #[derive(Debug)]
 enum ExecutionState {
     Pending,
+    Transform,
     Partition,
     MaterializeApplyCommitFlush,
     Register,
@@ -141,6 +143,7 @@ pub struct CompactOrchestrator {
     hnsw_provider: HnswIndexProvider,
     spann_provider: SpannProvider,
 
+    // TODO(jobs): Split this into source and dest collections.
     collection: OnceCell<Collection>,
     writers: OnceCell<CompactWriters>,
     flush_results: Vec<SegmentFlushInfo>,
@@ -162,6 +165,9 @@ pub struct CompactOrchestrator {
     segment_spans: HashMap<SegmentUuid, Span>,
 
     metrics: CompactOrchestratorMetrics,
+
+    // Functions
+    function: Option<()>,
 }
 
 #[derive(Error, Debug)]
@@ -192,6 +198,8 @@ pub enum CompactionError {
     Panic(#[from] PanicError),
     #[error("Error partitioning logs: {0}")]
     Partition(#[from] PartitionError),
+    #[error("Error transforming logs: {0}")]
+    Transform(#[from] TransformError),
     #[error("Error prefetching segment: {0}")]
     PrefetchSegment(#[from] PrefetchSegmentError),
     #[error("Error creating record segment reader: {0}")]
@@ -249,6 +257,7 @@ impl ChromaError for CompactionError {
             Self::MetadataSegment(e) => e.should_trace_error(),
             Self::Panic(e) => e.should_trace_error(),
             Self::Partition(e) => e.should_trace_error(),
+            Self::Transform(e) => e.should_trace_error(),
             Self::PrefetchSegment(e) => e.should_trace_error(),
             Self::RecordSegmentReader(e) => e.should_trace_error(),
             Self::RecordSegmentWriter(e) => e.should_trace_error(),
@@ -276,6 +285,7 @@ pub enum CompactionResponse {
 impl CompactOrchestrator {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
+        // TODO(jobs): Split this into source and dest collection IDs.
         collection_id: CollectionUuid,
         rebuild: bool,
         fetch_log_batch_size: u32,
@@ -316,6 +326,7 @@ impl CompactOrchestrator {
             num_materialized_logs: 0,
             segment_spans: HashMap::new(),
             metrics: CompactOrchestratorMetrics::default(),
+            function: Some(()),
         }
     }
 
@@ -325,6 +336,36 @@ impl CompactOrchestrator {
         }
     }
 
+    async fn transform_or_partition(
+        &mut self,
+        records: Chunk<LogRecord>,
+        ctx: &ComponentContext<CompactOrchestrator>,
+    ) {
+        if let Some(_fn) = self.function {
+            self.transform(records, ctx).await;
+        } else {
+            self.partition(records, ctx).await;
+        }
+    }
+
+    async fn transform(
+        &mut self,
+        records: Chunk<LogRecord>,
+        ctx: &ComponentContext<CompactOrchestrator>,
+    ) {
+        self.state = ExecutionState::Transform;
+        let operator = TransformOperator::new();
+        tracing::info!("Transforming {} records", records.len());
+        let input = TransformInput::new(records);
+        let task = wrap(
+            operator,
+            input,
+            ctx.receiver(),
+            self.context.task_cancellation_token.clone(),
+        );
+        self.send(task, ctx, Some(Span::current())).await;
+    }
+
     async fn partition(
         &mut self,
         records: Chunk<LogRecord>,
@@ -974,7 +1015,7 @@ impl Handler<TaskResult<FetchLogOutput, FetchLogError>> for CompactOrchestrator
                 return;
             }
         }
-        self.partition(output, ctx).await;
+        self.transform_or_partition(output, ctx).await;
     }
 }
 
@@ -1012,11 +1053,28 @@ impl Handler<TaskResult<SourceRecordSegmentOutput, SourceRecordSegmentError>>
             )
             .await;
         } else {
-            self.partition(output, ctx).await;
+            self.transform_or_partition(output, ctx).await;
        }
    }
}
 
+#[async_trait]
+impl Handler<TaskResult<TransformOutput, TransformError>> for CompactOrchestrator {
+    type Result = ();
+
+    async fn handle(
+        &mut self,
+        message: TaskResult<TransformOutput, TransformError>,
+        ctx: &ComponentContext<CompactOrchestrator>,
+    ) {
+        let output = match self.ok_or_terminate(message.into_inner(), ctx).await {
+            Some(recs) => recs.records,
+            None => return,
+        };
+        self.partition(output, ctx).await;
+    }
+}
+
 #[async_trait]
 impl Handler<TaskResult<PartitionOutput, PartitionError>> for CompactOrchestrator {
     type Result = ();
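
Reviewer note (not part of the diff): the sketch below is a simplified, std-only illustration of the pattern this change introduces, an identity (pass-through) transform stage gated by an Option field, mirroring transform_or_partition above. The names Record, ToyOperator, IdentityTransform, and ToyOrchestrator are hypothetical and do not exist in chroma_system; the real Operator trait is async, takes its input by reference, and returns Result<Output, Self::Error>.

// Hypothetical, self-contained sketch of the gated pass-through transform stage.
#[derive(Debug, Clone)]
struct Record {
    id: String,
}

trait ToyOperator {
    fn run(&self, input: &[Record]) -> Vec<Record>;
}

struct IdentityTransform;

impl ToyOperator for IdentityTransform {
    // Pass-through: returns its input unchanged, like TransformOperator::transform.
    fn run(&self, input: &[Record]) -> Vec<Record> {
        input.to_vec()
    }
}

struct ToyOrchestrator {
    // Mirrors `function: Option<()>`: Some(_) routes records through the transform stage first.
    function: Option<()>,
}

impl ToyOrchestrator {
    fn transform_or_partition(&self, records: Vec<Record>) -> Vec<Record> {
        match self.function {
            // Transform first; the real orchestrator then partitions in its TransformOutput handler.
            Some(()) => IdentityTransform.run(&records),
            // No function configured: hand the records straight to partitioning.
            None => records,
        }
    }
}

fn main() {
    let orchestrator = ToyOrchestrator { function: Some(()) };
    let records = vec![Record { id: "r1".into() }, Record { id: "r2".into() }];
    let out = orchestrator.transform_or_partition(records);
    assert_eq!(out.len(), 2); // the identity transform preserves every record
    println!("transformed {} records, first id = {}", out.len(), out[0].id);
}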