diff --git a/Cargo.lock b/Cargo.lock index f7802002a29..ef35ed97748 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1992,6 +1992,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chroma-error", + "chrono", "proptest", "proptest-derive 0.5.1", "prost 0.13.5", @@ -7140,6 +7141,7 @@ dependencies = [ "bytes", "chroma-error", "chroma-storage", + "chroma-sysdb", "chrono", "futures", "parking_lot", @@ -7160,8 +7162,10 @@ dependencies = [ "chroma-config", "chroma-error", "chroma-storage", + "chroma-sysdb", "chroma-tracing", "chroma-types", + "chrono", "figment", "futures", "s3heap", diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index 731c777eba0..35a8205434f 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -265,6 +265,39 @@ func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteT }, nil } +// Mark a task run as complete and set the nonce for the next task run. +func (s *Coordinator) AdvanceTask(ctx context.Context, req *coordinatorpb.AdvanceTaskRequest) (*coordinatorpb.AdvanceTaskResponse, error) { + if req.TaskId == nil { + log.Error("AdvanceTask: task_id is required") + return nil, status.Errorf(codes.InvalidArgument, "task_id is required") + } + + if req.TaskRunNonce == nil { + log.Error("AdvanceTask: task_run_nonce is required") + return nil, status.Errorf(codes.InvalidArgument, "task_run_nonce is required") + } + + taskID, err := uuid.Parse(*req.TaskId) + if err != nil { + log.Error("AdvanceTask: invalid task_id", zap.Error(err)) + return nil, status.Errorf(codes.InvalidArgument, "invalid task_id: %v", err) + } + + taskRunNonce, err := uuid.Parse(*req.TaskRunNonce) + if err != nil { + log.Error("AdvanceTask: invalid task_run_nonce", zap.Error(err)) + return nil, status.Errorf(codes.InvalidArgument, "invalid task_run_nonce: %v", err) + } + + err = s.catalog.metaDomain.TaskDb(ctx).AdvanceTask(taskID, taskRunNonce) + if err != nil { + log.Error("AdvanceTask failed", zap.Error(err), zap.String("task_id", taskID.String())) + return nil, err + } + + return &coordinatorpb.AdvanceTaskResponse{}, nil +} + // GetOperators retrieves all operators from the database func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) { operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll() @@ -288,3 +321,33 @@ func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOp Operators: protoOperators, }, nil } + +// PeekScheduleByCollectionId gives, for a vector of collection IDs, a vector of schedule entries, +// including when to run and the nonce to use for said run. 
+func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) { + tasks, err := s.catalog.metaDomain.TaskDb(ctx).PeekScheduleByCollectionId(req.CollectionId) + if err != nil { + log.Error("PeekScheduleByCollectionId failed", zap.Error(err)) + return nil, err + } + + scheduleEntries := make([]*coordinatorpb.ScheduleEntry, 0, len(tasks)) + for _, task := range tasks { + task_id := task.ID.String() + entry := &coordinatorpb.ScheduleEntry{ + CollectionId: &task.InputCollectionID, + TaskId: &task_id, + TaskRunNonce: proto.String(task.NextNonce.String()), + WhenToRun: nil, + } + if task.NextRun != nil { + whenToRun := uint64(task.NextRun.UnixMilli()) + entry.WhenToRun = &whenToRun + } + scheduleEntries = append(scheduleEntries, entry) + } + + return &coordinatorpb.PeekScheduleByCollectionIdResponse{ + Schedule: scheduleEntries, + }, nil +} diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go index baf6ba1826f..0b2415493df 100644 --- a/go/pkg/sysdb/grpc/task_service.go +++ b/go/pkg/sysdb/grpc/task_service.go @@ -52,6 +52,18 @@ func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRe return res, nil } +func (s *Server) AdvanceTask(ctx context.Context, req *coordinatorpb.AdvanceTaskRequest) (*coordinatorpb.AdvanceTaskResponse, error) { + log.Info("AdvanceTask", zap.String("collection_id", req.GetCollectionId()), zap.String("task_id", req.GetTaskId())) + + res, err := s.coordinator.AdvanceTask(ctx, req) + if err != nil { + log.Error("AdvanceTask failed", zap.Error(err)) + return nil, err + } + + return res, nil +} + func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) { log.Info("GetOperators") @@ -63,3 +75,15 @@ func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperato return res, nil } + +func (s *Server) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) { + log.Info("PeekScheduleByCollectionId", zap.Int64("num_collections", int64(len(req.CollectionId)))) + + res, err := s.coordinator.PeekScheduleByCollectionId(ctx, req) + if err != nil { + log.Error("PeekScheduleByCollectionId failed", zap.Error(err)) + return nil, err + } + + return res, nil +} diff --git a/go/pkg/sysdb/metastore/db/dao/task.go b/go/pkg/sysdb/metastore/db/dao/task.go index f60c64b7c49..1c0b07d7ffe 100644 --- a/go/pkg/sysdb/metastore/db/dao/task.go +++ b/go/pkg/sysdb/metastore/db/dao/task.go @@ -2,9 +2,11 @@ package dao import ( "errors" + "time" "github.com/chroma-core/chroma/go/pkg/common" "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" "github.com/pingcap/log" "go.uber.org/zap" @@ -58,6 +60,55 @@ func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel. return &task, nil } +func (s *taskDb) GetByID(taskID uuid.UUID) (*dbmodel.Task, error) { + var task dbmodel.Task + err := s.db. + Where("task_id = ?", taskID). + Where("is_deleted = ?", false). 
+ First(&task).Error + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + log.Error("GetByID failed", zap.Error(err), zap.String("task_id", taskID.String())) + return nil, err + } + return &task, nil +} + +func (s *taskDb) AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error { + nextNonce, err := uuid.NewV7() + if err != nil { + log.Error("AdvanceTask: failed to generate next nonce", zap.Error(err)) + return err + } + + now := time.Now() + result := s.db.Exec(` + UPDATE tasks + SET next_nonce = ?, + updated_at = GREATEST(updated_at, GREATEST(?, last_run)), + last_run = ?, + current_attempts = 0 + WHERE task_id = ? + AND next_nonce = ? + AND is_deleted = false + `, nextNonce, now, now, taskID, taskRunNonce) + + if result.Error != nil { + log.Error("AdvanceTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) + return result.Error + } + + if result.RowsAffected == 0 { + log.Warn("AdvanceTask: no rows affected", zap.String("task_id", taskID.String()), zap.String("task_run_nonce", taskRunNonce.String())) + return common.ErrTaskNotFound + } + + return nil +} + func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error { // Update task name and is_deleted in a single query // Format: _deleted___ @@ -82,3 +133,17 @@ func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error { return nil } + +func (s *taskDb) PeekScheduleByCollectionId(collectionIDs []string) ([]*dbmodel.Task, error) { + var tasks []*dbmodel.Task + err := s.db. + Where("input_collection_id IN ?", collectionIDs). + Where("is_deleted = ?", false). + Find(&tasks).Error + + if err != nil { + log.Error("PeekScheduleByCollectionId failed", zap.Error(err)) + return nil, err + } + return tasks, nil +} diff --git a/go/pkg/sysdb/metastore/db/dao/task_test.go b/go/pkg/sysdb/metastore/db/dao/task_test.go index 94de1147608..3497c0df316 100644 --- a/go/pkg/sysdb/metastore/db/dao/task_test.go +++ b/go/pkg/sysdb/metastore/db/dao/task_test.go @@ -312,6 +312,144 @@ func (suite *TaskDbTestSuite) TestTaskDb_DeleteAll() { } } +func (suite *TaskDbTestSuite) TestTaskDb_GetByID() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + nextNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test-get-by-id-task", + OperatorID: operatorID, + InputCollectionID: "input_col_id", + OutputCollectionName: "output_col_name", + OperatorParams: "{}", + TenantID: "tenant1", + DatabaseID: "db1", + MinRecordsForTask: 100, + NextNonce: nextNonce, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + retrieved, err := suite.Db.GetByID(taskID) + suite.Require().NoError(err) + suite.Require().NotNil(retrieved) + suite.Require().Equal(task.ID, retrieved.ID) + suite.Require().Equal(task.Name, retrieved.Name) + suite.Require().Equal(task.OperatorID, retrieved.OperatorID) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + +func (suite *TaskDbTestSuite) TestTaskDb_GetByID_NotFound() { + retrieved, err := suite.Db.GetByID(uuid.New()) + suite.Require().NoError(err) + suite.Require().Nil(retrieved) +} + +func (suite *TaskDbTestSuite) TestTaskDb_GetByID_IgnoresDeleted() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + nextNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test-get-by-id-deleted", + OperatorID: operatorID, + InputCollectionID: "input1", + OutputCollectionName: "output1", + OperatorParams: "{}", + TenantID: "tenant1", + DatabaseID: 
"db1", + MinRecordsForTask: 100, + NextNonce: nextNonce, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + err = suite.Db.SoftDelete("input1", "test-get-by-id-deleted") + suite.Require().NoError(err) + + retrieved, err := suite.Db.GetByID(taskID) + suite.Require().NoError(err) + suite.Require().Nil(retrieved) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + +func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + originalNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test-advance-task", + OperatorID: operatorID, + InputCollectionID: "input_col_id", + OutputCollectionName: "output_col_name", + OperatorParams: "{}", + TenantID: "tenant1", + DatabaseID: "db1", + MinRecordsForTask: 100, + NextNonce: originalNonce, + CurrentAttempts: 3, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + err = suite.Db.AdvanceTask(taskID, originalNonce) + suite.Require().NoError(err) + + retrieved, err := suite.Db.GetByID(taskID) + suite.Require().NoError(err) + suite.Require().NotNil(retrieved) + suite.Require().NotEqual(originalNonce, retrieved.NextNonce) + suite.Require().NotNil(retrieved.LastRun) + suite.Require().Equal(int32(0), retrieved.CurrentAttempts) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + +func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_InvalidNonce() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + correctNonce, _ := uuid.NewV7() + wrongNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test-advance-task-wrong-nonce", + OperatorID: operatorID, + InputCollectionID: "input_col_id", + OutputCollectionName: "output_col_name", + OperatorParams: "{}", + TenantID: "tenant1", + DatabaseID: "db1", + MinRecordsForTask: 100, + NextNonce: correctNonce, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + err = suite.Db.AdvanceTask(taskID, wrongNonce) + suite.Require().Error(err) + suite.Require().Equal(common.ErrTaskNotFound, err) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + +func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_NotFound() { + err := suite.Db.AdvanceTask(uuid.New(), uuid.Must(uuid.NewV7())) + suite.Require().Error(err) + suite.Require().Equal(common.ErrTaskNotFound, err) +} + // TestOperatorConstantsMatchSeededDatabase verifies that operator constants in // dbmodel/constants.go match what we seed in the test database (which should match migrations). // This catches drift between constants and migrations at test time. 
diff --git a/go/pkg/sysdb/metastore/db/dbmodel/task.go b/go/pkg/sysdb/metastore/db/dbmodel/task.go index 9dd79f9ce79..32935b67ecd 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/task.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/task.go @@ -38,6 +38,9 @@ func (v Task) TableName() string { type ITaskDb interface { Insert(task *Task) error GetByName(inputCollectionID string, taskName string) (*Task, error) + GetByID(taskID uuid.UUID) (*Task, error) + AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error SoftDelete(inputCollectionID string, taskName string) error DeleteAll() error + PeekScheduleByCollectionId(collectionIDs []string) ([]*Task, error) } diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index e8dc47794fc..c8ee3653d47 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -570,6 +570,14 @@ message DeleteTaskResponse { bool success = 1; } +message AdvanceTaskRequest { + optional string collection_id = 1; + optional string task_id = 2; + optional string task_run_nonce = 3; +} + +message AdvanceTaskResponse {} + message Operator { string id = 1; string name = 2; @@ -583,6 +591,21 @@ message GetOperatorsResponse { repeated Operator operators = 1; } +message PeekScheduleByCollectionIdRequest { + repeated string collection_id = 1; +} + +message ScheduleEntry { + optional string collection_id = 1; + optional string task_id = 2; + optional string task_run_nonce = 3; + optional uint64 when_to_run = 4; +} + +message PeekScheduleByCollectionIdResponse { + repeated ScheduleEntry schedule = 1; +} + service SysDB { rpc CreateDatabase(CreateDatabaseRequest) returns (CreateDatabaseResponse) {} rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {} @@ -623,5 +646,7 @@ service SysDB { rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse) {} rpc GetTaskByName(GetTaskByNameRequest) returns (GetTaskByNameResponse) {} rpc DeleteTask(DeleteTaskRequest) returns (DeleteTaskResponse) {} + rpc AdvanceTask(AdvanceTaskRequest) returns (AdvanceTaskResponse) {} rpc GetOperators(GetOperatorsRequest) returns (GetOperatorsResponse) {} + rpc PeekScheduleByCollectionId(PeekScheduleByCollectionIdRequest) returns (PeekScheduleByCollectionIdResponse) {} } diff --git a/rust/s3heap-service/Cargo.toml b/rust/s3heap-service/Cargo.toml index f9e69b94852..c2244a6d2ac 100644 --- a/rust/s3heap-service/Cargo.toml +++ b/rust/s3heap-service/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] async-trait = { workspace = true } +chrono = { workspace = true } figment = { workspace = true } futures = { workspace = true } serde = { workspace = true } @@ -17,6 +18,7 @@ tracing = { workspace = true } chroma-config = { workspace = true } chroma-error = { workspace = true } chroma-storage = { workspace = true } +chroma-sysdb = { workspace = true } chroma-tracing = { workspace = true, features = ["grpc"] } chroma-types = { workspace = true } s3heap = { workspace = true } diff --git a/rust/s3heap-service/src/lib.rs b/rust/s3heap-service/src/lib.rs index 52df6bedfb6..f0767f73ee2 100644 --- a/rust/s3heap-service/src/lib.rs +++ b/rust/s3heap-service/src/lib.rs @@ -12,21 +12,20 @@ use chroma_config::Configurable; use chroma_error::ChromaError; use chroma_storage::config::StorageConfig; use chroma_storage::Storage; +use chroma_sysdb::{SysDb, SysDbConfig}; use chroma_tracing::OtelFilter; use chroma_tracing::OtelFilterLevel; use chroma_types::chroma_proto::heap_tender_service_server::{ HeapTenderService, HeapTenderServiceServer, }; use 
chroma_types::chroma_proto::{HeapSummaryRequest, HeapSummaryResponse}; -use chroma_types::{dirty_log_path_from_hostname, CollectionUuid, DirtyMarker}; +use chroma_types::{dirty_log_path_from_hostname, CollectionUuid, DirtyMarker, ScheduleEntry}; +use s3heap::{Configuration, DummyScheduler, Error, HeapWriter, Schedule, Triggerable}; use wal3::{ Cursor, CursorName, CursorStore, CursorStoreOptions, LogPosition, LogReader, LogReaderOptions, Witness, }; -use s3heap::DummyScheduler; -use s3heap::{Configuration, Error, HeapWriter}; - ///////////////////////////////////////////// constants //////////////////////////////////////////// const DEFAULT_CONFIG_PATH: &str = "./chroma_config.yaml"; @@ -46,26 +45,58 @@ pub static HEAP_TENDER_CURSOR_NAME: CursorName = /// Manages heap compaction by reading dirty logs and coordinating with HeapWriter. pub struct HeapTender { + #[allow(dead_code)] + sysdb: SysDb, reader: LogReader, cursor: CursorStore, - _writer: HeapWriter, + writer: HeapWriter, } impl HeapTender { /// Creates a new HeapTender. - pub fn new(reader: LogReader, cursor: CursorStore, writer: HeapWriter) -> Self { + pub fn new(sysdb: SysDb, reader: LogReader, cursor: CursorStore, writer: HeapWriter) -> Self { Self { + sysdb, reader, cursor, - _writer: writer, + writer, } } /// Tends to the heap by reading and coalescing the dirty log, then updating the cursor. pub async fn tend_to_heap(&self) -> Result<(), Error> { let (witness, cursor, tended) = self.read_and_coalesce_dirty_log().await?; - // TODO(rescrv): Do something with tended and update the cursor iff tended is false. - _ = tended; + if !tended.is_empty() { + let collection_ids = tended.iter().map(|t| t.0).collect::<Vec<_>>(); + let scheduled = self + .sysdb + .clone() + .peek_schedule_by_collection_id(&collection_ids) + .await?; + let triggerables: Vec<Option<Schedule>> = scheduled + .into_iter() + .map(|s: ScheduleEntry| -> Result<_, Error> { + let triggerable = Triggerable { + partitioning: s3heap::UnitOfPartitioningUuid::new(s.collection_id.0), + scheduling: s3heap::UnitOfSchedulingUuid::new(s.task_id), + }; + if let Some(next_scheduled) = s.when_to_run { + let schedule = Schedule { + triggerable, + next_scheduled, + nonce: s.task_run_nonce, + }; + Ok(Some(schedule)) + } else { + Ok(None) + } + }) + .collect::<Result<Vec<_>, _>>()?; + let triggerables: Vec<Schedule> = triggerables.into_iter().flatten().collect(); + if !triggerables.is_empty() { + self.writer.push(&triggerables).await?; + } + } if let Some(witness) = witness.as_ref() { self.cursor .save(&HEAP_TENDER_CURSOR_NAME, &cursor, witness) @@ -243,6 +274,13 @@ impl Configurable<HeapTenderServerConfig> for HeapTenderServer { config: &HeapTenderServerConfig, registry: &chroma_config::registry::Registry, ) -> Result<Self, Box<dyn ChromaError>> { + match &config.sysdb { + chroma_sysdb::SysDbConfig::Grpc(_) => {} + chroma_sysdb::SysDbConfig::Sqlite(_) => { + panic!("Expected grpc sysdb config, got sqlite sysdb config") + } + }; + let sysdb = SysDb::try_from_config(&config.sysdb, registry).await?; let storage = Storage::try_from_config(&config.storage, registry).await?; let dirty_log_prefix = dirty_log_path_from_hostname(&config.my_member_id); let reader = LogReader::new( @@ -258,12 +296,13 @@ impl Configurable<HeapTenderServerConfig> for HeapTenderServer { ); let heap_prefix = heap_path_from_hostname(&config.my_member_id); let scheduler = Arc::new(DummyScheduler) as _; - let writer = HeapWriter::new(storage, heap_prefix, scheduler) + let writer = HeapWriter::new(storage, heap_prefix, Arc::clone(&scheduler)) .map_err(|e| -> Box<dyn ChromaError> { Box::new(e) })?; let tender = Arc::new(HeapTender { + sysdb, reader, cursor, - 
_writer: writer, + writer, }); Ok(Self { config: config.clone(), @@ -413,6 +452,9 @@ pub struct HeapTenderServerConfig { /// Optional OpenTelemetry configuration for tracing. #[serde(default)] pub opentelemetry: Option, + /// Configuration for the sysdb backend. + #[serde(default = "HeapTenderServerConfig::default_sysdb_config")] + pub sysdb: SysDbConfig, /// Configuration for the S3 storage backend. #[serde(default)] pub storage: StorageConfig, @@ -465,6 +507,10 @@ impl HeapTenderServerConfig { Duration::from_secs(10) } + fn default_sysdb_config() -> SysDbConfig { + SysDbConfig::Grpc(Default::default()) + } + fn default_grpc_shutdown_grace_period() -> Duration { Duration::from_secs(1) } @@ -476,6 +522,7 @@ impl Default for HeapTenderServerConfig { port: HeapTenderServerConfig::default_port(), my_member_id: HeapTenderServerConfig::default_my_member_id(), opentelemetry: None, + sysdb: HeapTenderServerConfig::default_sysdb_config(), storage: StorageConfig::default(), reader: LogReaderOptions::default(), cursor: CursorStoreOptions::default(), diff --git a/rust/s3heap-service/tests/test_k8s_integration_00_heap_tender.rs b/rust/s3heap-service/tests/test_k8s_integration_00_heap_tender.rs index 57ea3ca6a14..9c77e889d6f 100644 --- a/rust/s3heap-service/tests/test_k8s_integration_00_heap_tender.rs +++ b/rust/s3heap-service/tests/test_k8s_integration_00_heap_tender.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use chroma_storage::Storage; +use chroma_sysdb::{SysDb, TestSysDb}; use chroma_types::{CollectionUuid, DirtyMarker}; use wal3::{CursorStore, CursorStoreOptions, LogPosition, LogReader, LogReaderOptions}; @@ -14,6 +15,7 @@ fn test_heap_tender(storage: Storage, test_id: &str) -> HeapTender { } fn create_heap_tender(storage: Storage, dirty_log_prefix: &str, heap_prefix: &str) -> HeapTender { + let sysdb = SysDb::Test(TestSysDb::new()); let reader = LogReader::new( LogReaderOptions::default(), Arc::new(storage.clone()), @@ -26,8 +28,8 @@ fn create_heap_tender(storage: Storage, dirty_log_prefix: &str, heap_prefix: &st "test-tender".to_string(), ); let scheduler = Arc::new(DummyScheduler) as _; - let writer = HeapWriter::new(storage, heap_prefix.to_string(), scheduler).unwrap(); - HeapTender::new(reader, cursor, writer) + let writer = HeapWriter::new(storage, heap_prefix.to_string(), Arc::clone(&scheduler)).unwrap(); + HeapTender::new(sysdb, reader, cursor, writer) } #[tokio::test] diff --git a/rust/s3heap/Cargo.toml b/rust/s3heap/Cargo.toml index 9987390377c..6b0efc04400 100644 --- a/rust/s3heap/Cargo.toml +++ b/rust/s3heap/Cargo.toml @@ -18,6 +18,7 @@ uuid = { workspace = true } chroma-error = { workspace = true } chroma-storage = { workspace = true } +chroma-sysdb = { workspace = true } wal3 = { workspace = true } [dev-dependencies] diff --git a/rust/s3heap/src/internal.rs b/rust/s3heap/src/internal.rs index 684498de97a..a8a7f82b965 100644 --- a/rust/s3heap/src/internal.rs +++ b/rust/s3heap/src/internal.rs @@ -67,7 +67,7 @@ fn get_string_column<'a>( /// nonce: Uuid::new_v4(), /// }; /// ``` -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)] pub struct HeapItem { /// The triggerable task to be executed pub trigger: Triggerable, @@ -154,9 +154,24 @@ impl Internal { let entries = entries.to_vec(); (|| async { - let (mut on_s3, e_tag) = self.load_bucket_or_empty(bucket).await?; - on_s3.extend(entries.iter().cloned()); - self.store_bucket(bucket, &on_s3, e_tag).await + let (on_s3, e_tag) = self.load_bucket_or_empty(bucket).await?; + let 
triggerables = on_s3 + .iter() + .map(|x| (x.trigger, x.nonce)) + .collect::<Vec<_>>(); + let schedules = self.heap_scheduler.are_done(&triggerables).await?; + let mut results = Vec::with_capacity(on_s3.len().saturating_add(entries.len())); + for (item, is_done) in on_s3.into_iter().zip(schedules) { + if !is_done { + results.push(item) + } + } + results.extend(entries.clone()); + results.sort(); + results.reverse(); + results.dedup_by_key(|x| x.trigger); + results.reverse(); + self.store_bucket(bucket, &results, e_tag).await }) .retry(backoff) .await @@ -562,15 +577,9 @@ mod tests { }; let nonce = Uuid::new_v4(); - let item1 = HeapItem { - trigger: trigger.clone(), - nonce, - }; + let item1 = HeapItem { trigger, nonce }; - let item2 = HeapItem { - trigger: trigger.clone(), - nonce, - }; + let item2 = HeapItem { trigger, nonce }; assert_eq!(item1, item2); assert_eq!(item1.trigger, trigger); diff --git a/rust/s3heap/src/lib.rs b/rust/s3heap/src/lib.rs index 403f089ef34..7ca20badb6b 100644 --- a/rust/s3heap/src/lib.rs +++ b/rust/s3heap/src/lib.rs @@ -142,6 +142,9 @@ pub enum Error { /// Date rounding error #[error("could not round date: {0}")] RoundError(#[from] chrono::RoundingError), + /// SysDb error + #[error("sysdb error: {0}")] + SysDb(#[from] chroma_sysdb::PeekScheduleError), } impl chroma_error::ChromaError for Error { @@ -162,6 +165,7 @@ impl chroma_error::ChromaError for Error { Error::Arrow(_) => ErrorCodes::Internal, Error::ParseDate(_) => ErrorCodes::InvalidArgument, Error::RoundError(_) => ErrorCodes::Internal, + Error::SysDb(e) => e.code(), } } } @@ -362,7 +366,7 @@ impl Limits { /// The UnitOfPartitioning is e.g. a Chroma collection or some other unit of work that is a /// functional dependency of the key used for partitioning. Always a UUID. -#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Hash)] pub struct UnitOfPartitioningUuid(Uuid); impl UnitOfPartitioningUuid { @@ -391,7 +395,7 @@ impl fmt::Display for UnitOfPartitioningUuid { /// The UnitOfScheduling is the identifier for the individual thing to push and pop off the heap. A /// given UnitOfPartitioning may have many UnitOfScheduling UUIDs assigned to it. 
-#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Hash)] pub struct UnitOfSchedulingUuid(Uuid); impl UnitOfSchedulingUuid { @@ -440,7 +444,7 @@ impl fmt::Display for UnitOfSchedulingUuid { /// scheduling: UnitOfSchedulingUuid::new(Uuid::new_v4()), /// }; /// ``` -#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Hash)] pub struct Triggerable { /// The UUID identifying the partitioning unit pub partitioning: UnitOfPartitioningUuid, @@ -493,7 +497,6 @@ pub struct Schedule { /// #[async_trait::async_trait] /// impl HeapScheduler for MyScheduler { /// async fn are_done(&self, items: &[(Triggerable, Uuid)]) -> Result<Vec<bool>, Error> { -/// // Check if tasks are complete in your system /// let completed = self.completed_tasks.lock(); /// Ok(items.iter() /// .map(|(item, nonce)| completed.get(&(*item.partitioning.as_uuid(), *item.scheduling.as_uuid(), *nonce)).copied().unwrap_or(false)) @@ -525,7 +528,7 @@ pub trait HeapScheduler: Send + Sync { /// * `Ok(false)` if the task is still pending or running /// * `Err` if there was an error checking the status async fn is_done(&self, item: &Triggerable, nonce: Uuid) -> Result<bool, Error> { - let results = self.are_done(&[(item.clone(), nonce)]).await?; + let results = self.are_done(&[(*item, nonce)]).await?; if results.len() != 1 { return Err(Error::Internal(format!( "are_done returned {} results for 1 item", @@ -733,7 +736,7 @@ impl HeapWriter { for schedule in schedules { let heap_item = HeapItem { - trigger: schedule.triggerable.clone(), + trigger: schedule.triggerable, nonce: schedule.nonce, }; let bucket = self.internal.compute_bucket(schedule.next_scheduled)?; @@ -950,7 +953,7 @@ impl HeapPruner { let original_count = entries.len(); let triggers = entries .iter() - .map(|e| (e.trigger.clone(), e.nonce)) + .map(|e| (e.trigger, e.nonce)) .collect::<Vec<_>>(); let are_done = heap_scheduler.are_done(&triggers).await?; @@ -1146,7 +1149,7 @@ impl HeapReader { let triggerable_and_nonce = entries .iter() .filter(|hi| should_return(&hi.trigger)) - .map(|hi| (hi.trigger.clone(), hi.nonce)) + .map(|hi| (hi.trigger, hi.nonce)) .collect::<Vec<_>>(); let are_done = heap_scheduler.are_done(&triggerable_and_nonce).await?; if triggerable_and_nonce.len() != are_done.len() { @@ -1159,7 +1162,7 @@ impl HeapReader { for ((triggerable, uuid), is_done) in triggerable_and_nonce.iter().zip(are_done) { if !is_done { returns.push(HeapItem { - trigger: triggerable.clone(), + trigger: *triggerable, nonce: *uuid, }); if returns.len() >= max_items { diff --git a/rust/s3heap/tests/common.rs b/rust/s3heap/tests/common.rs index a5ce6286f2b..25e8c00cf48 100644 --- a/rust/s3heap/tests/common.rs +++ b/rust/s3heap/tests/common.rs @@ -271,14 +271,16 @@ impl<'a> TestItemBuilder<'a> { let time = test_time_at_minute_offset(base, self.time_offset_minutes); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: time, nonce, }; - self.scheduler - .set_schedule(*item.scheduling.as_uuid(), Some(schedule.clone())); + self.scheduler.set_schedule( + *schedule.triggerable.scheduling.as_uuid(), + Some(schedule.clone()), + ); if let Some(done) = self.is_done { - self.scheduler.set_done(&item, nonce, done); + self.scheduler.set_done(&schedule.triggerable, nonce, done); } schedule } diff --git a/rust/s3heap/tests/test_k8s_integration_03_merge_buckets.rs b/rust/s3heap/tests/test_k8s_integration_03_merge_buckets.rs index 3fcbcf566e7..60e95d71dc4 
100644 --- a/rust/s3heap/tests/test_k8s_integration_03_merge_buckets.rs +++ b/rust/s3heap/tests/test_k8s_integration_03_merge_buckets.rs @@ -22,17 +22,17 @@ async fn test_k8s_integration_03_merge_same_bucket() { let now = Utc::now().duration_trunc(TimeDelta::minutes(1)).unwrap(); let base_time = test_time_at_minute_offset(now, 5); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: base_time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: base_time + Duration::seconds(10), nonce: test_nonce(2), }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: base_time + Duration::seconds(30), nonce: test_nonce(3), }; @@ -92,12 +92,12 @@ async fn test_k8s_integration_03_merge_multiple_pushes() { let push_time = test_time_at_minute_offset(now, 10); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: push_time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: push_time + Duration::seconds(5), nonce: test_nonce(2), }; @@ -114,12 +114,12 @@ async fn test_k8s_integration_03_merge_multiple_pushes() { let item4 = create_test_triggerable(4, 4); let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: push_time + Duration::seconds(20), nonce: test_nonce(3), }; let schedule4 = Schedule { - triggerable: item4.clone(), + triggerable: item4, next_scheduled: push_time + Duration::seconds(40), nonce: test_nonce(4), }; diff --git a/rust/s3heap/tests/test_k8s_integration_05_peek_filtering.rs b/rust/s3heap/tests/test_k8s_integration_05_peek_filtering.rs index 9ce4cf35c83..bc5f03027ff 100644 --- a/rust/s3heap/tests/test_k8s_integration_05_peek_filtering.rs +++ b/rust/s3heap/tests/test_k8s_integration_05_peek_filtering.rs @@ -25,27 +25,27 @@ async fn test_k8s_integration_05_peek_all_items() { let now = Utc::now(); let time = test_time_at_minute_offset(now, 5); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: time, nonce: test_nonce(2), }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: time, nonce: test_nonce(3), }; let schedule4 = Schedule { - triggerable: item4.clone(), + triggerable: item4, next_scheduled: time, nonce: test_nonce(4), }; let schedule5 = Schedule { - triggerable: item5.clone(), + triggerable: item5, next_scheduled: time, nonce: test_nonce(5), }; @@ -102,27 +102,27 @@ async fn test_k8s_integration_05_peek_with_filter() { let now = Utc::now(); let time = test_time_at_minute_offset(now, 5); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: time, nonce: test_nonce(2), }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: time, nonce: test_nonce(3), }; let schedule4 = Schedule { - triggerable: item4.clone(), + triggerable: item4, next_scheduled: time, nonce: test_nonce(4), }; let schedule5 = Schedule { - triggerable: item5.clone(), + triggerable: item5, next_scheduled: time, nonce: test_nonce(5), }; @@ -209,17 +209,17 @@ async fn 
test_k8s_integration_05_peek_filters_completed() { let nonce3 = test_nonce(3); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: time, nonce: nonce1, }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: time, nonce: nonce2, }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: time, nonce: nonce3, }; @@ -278,22 +278,22 @@ async fn test_k8s_integration_05_peek_across_buckets() { let time2 = test_time_at_minute_offset(now, 10); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: time1, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: time1, nonce: test_nonce(2), }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: time2, nonce: test_nonce(3), }; let schedule4 = Schedule { - triggerable: item4.clone(), + triggerable: item4, next_scheduled: time2, nonce: test_nonce(4), }; diff --git a/rust/s3heap/tests/test_k8s_integration_06_retry_logic.rs b/rust/s3heap/tests/test_k8s_integration_06_retry_logic.rs index 1f764709a88..7db11f0c56f 100644 --- a/rust/s3heap/tests/test_k8s_integration_06_retry_logic.rs +++ b/rust/s3heap/tests/test_k8s_integration_06_retry_logic.rs @@ -40,12 +40,12 @@ async fn test_k8s_integration_06_concurrent_writes_with_retry() { let item2 = create_test_triggerable(2, 2); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: time, nonce: test_nonce(2), }; @@ -73,7 +73,7 @@ async fn test_k8s_integration_06_prune_with_retry() { let nonce = test_nonce(1); let now = Utc::now(); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: test_time_at_minute_offset(now, 3), nonce, }; diff --git a/rust/s3heap/tests/test_k8s_integration_07_bucket_computation.rs b/rust/s3heap/tests/test_k8s_integration_07_bucket_computation.rs index 900872e4965..b4301f60d74 100644 --- a/rust/s3heap/tests/test_k8s_integration_07_bucket_computation.rs +++ b/rust/s3heap/tests/test_k8s_integration_07_bucket_computation.rs @@ -25,22 +25,22 @@ async fn test_k8s_integration_07_bucket_rounding() { let item4 = create_test_triggerable(4, 4); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: base_time, nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: base_time + Duration::seconds(15), nonce: test_nonce(2), }; let schedule3 = Schedule { - triggerable: item3.clone(), + triggerable: item3, next_scheduled: base_time + Duration::seconds(30), nonce: test_nonce(3), }; let schedule4 = Schedule { - triggerable: item4.clone(), + triggerable: item4, next_scheduled: base_time + Duration::seconds(59), nonce: test_nonce(4), }; @@ -88,12 +88,12 @@ async fn test_k8s_integration_07_bucket_boundaries() { let item2 = create_test_triggerable(2, 2); let schedule1 = Schedule { - triggerable: item1.clone(), + triggerable: item1, next_scheduled: minute1 + Duration::seconds(59), nonce: test_nonce(1), }; let schedule2 = Schedule { - triggerable: item2.clone(), + triggerable: item2, next_scheduled: minute2, nonce: test_nonce(2), }; @@ -134,7 +134,7 @@ async fn test_k8s_integration_07_bucket_path_format() { .with_timezone(&Utc); let schedule = 
Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: scheduled_time, nonce: test_nonce(1), }; @@ -180,7 +180,7 @@ async fn test_k8s_integration_07_multiple_buckets_ordering() { let item = create_test_triggerable(i as u32, i as u32); let time = base_time + Duration::minutes(i * 5); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: time, nonce: test_nonce(i as u32), }; diff --git a/rust/s3heap/tests/test_k8s_integration_08_concurrent_operations.rs b/rust/s3heap/tests/test_k8s_integration_08_concurrent_operations.rs index cc34712ca5f..f512cd82a54 100644 --- a/rust/s3heap/tests/test_k8s_integration_08_concurrent_operations.rs +++ b/rust/s3heap/tests/test_k8s_integration_08_concurrent_operations.rs @@ -26,7 +26,7 @@ async fn test_k8s_integration_08_concurrent_pushes() { scheduler.set_schedule( *item.scheduling.as_uuid(), Some(Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: bucket_time, nonce: test_nonce(i), }), @@ -91,7 +91,7 @@ async fn test_k8s_integration_08_concurrent_read_write() { .map(|i| { let item = create_test_triggerable(i, i); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: bucket_time, nonce: test_nonce(i), }; @@ -128,7 +128,7 @@ async fn test_k8s_integration_08_concurrent_read_write() { let idx = 100 + batch * 5 + i; let item = create_test_triggerable(idx, idx); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: bucket_time, nonce: test_nonce(idx), }; @@ -196,7 +196,7 @@ async fn test_k8s_integration_08_concurrent_prune_push() { let item = create_test_triggerable(i, i); let nonce = test_nonce(i); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: bucket_time, nonce, }; @@ -240,7 +240,7 @@ async fn test_k8s_integration_08_concurrent_prune_push() { .map(|i| { let item = create_test_triggerable(i, i); let schedule = Schedule { - triggerable: item.clone(), + triggerable: item, next_scheduled: bucket_time, nonce: test_nonce(i), }; diff --git a/rust/s3heap/tests/test_unit_tests.rs b/rust/s3heap/tests/test_unit_tests.rs index 57cee8ed453..e0f2cd041c8 100644 --- a/rust/s3heap/tests/test_unit_tests.rs +++ b/rust/s3heap/tests/test_unit_tests.rs @@ -228,7 +228,7 @@ fn triggerable_clone() { partitioning: Uuid::new_v4().into(), scheduling: Uuid::new_v4().into(), }; - let cloned = original.clone(); + let cloned = original; assert_eq!(original, cloned); assert_eq!(original.partitioning, cloned.partitioning); assert_eq!(original.scheduling, cloned.scheduling); diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index ee2ef67141d..527d8db3617 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -16,8 +16,9 @@ use chroma_types::{ GetDatabaseResponse, GetSegmentsError, GetTenantError, GetTenantResponse, InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListCollectionVersionsError, ListDatabasesError, ListDatabasesResponse, Metadata, ResetError, - ResetResponse, SegmentFlushInfo, SegmentFlushInfoConversionError, SegmentUuid, - UpdateCollectionError, UpdateTenantError, UpdateTenantResponse, VectorIndexConfiguration, + ResetResponse, ScheduleEntry, ScheduleEntryConversionError, SegmentFlushInfo, + SegmentFlushInfoConversionError, SegmentUuid, UpdateCollectionError, UpdateTenantError, + UpdateTenantResponse, VectorIndexConfiguration, }; use chroma_types::{ BatchGetCollectionSoftDeleteStatusError, BatchGetCollectionVersionFilePathsError, 
Collection, @@ -683,6 +684,17 @@ impl SysDb { SysDb::Test(_) => todo!(), } } + + pub async fn peek_schedule_by_collection_id( + &mut self, + collection_ids: &[CollectionUuid], + ) -> Result<Vec<ScheduleEntry>, PeekScheduleError> { + match self { + SysDb::Grpc(grpc) => grpc.peek_schedule_by_collection_id(collection_ids).await, + SysDb::Sqlite(_) => unimplemented!(), + SysDb::Test(test) => test.peek_schedule_by_collection_id(collection_ids).await, + } + } } #[derive(Clone, Debug)] @@ -1815,6 +1827,43 @@ impl GrpcSysDb { } } } + + async fn peek_schedule_by_collection_id( + &mut self, + collection_ids: &[CollectionUuid], + ) -> Result<Vec<ScheduleEntry>, PeekScheduleError> { + let req = chroma_proto::PeekScheduleByCollectionIdRequest { + collection_id: collection_ids.iter().map(|id| id.0.to_string()).collect(), + }; + let res = self + .client + .peek_schedule_by_collection_id(req) + .await + .map_err(|e| TonicError(e).boxed())?; + res.into_inner() + .schedule + .into_iter() + .map(|entry| entry.try_into()) + .collect::<Result<Vec<ScheduleEntry>, ScheduleEntryConversionError>>() + .map_err(PeekScheduleError::Conversion) + } +} + +#[derive(Error, Debug)] +pub enum PeekScheduleError { + #[error("Failed to peek schedule")] + Internal(#[from] Box<dyn ChromaError>), + #[error("Failed to convert schedule entry")] + Conversion(#[from] ScheduleEntryConversionError), +} + +impl ChromaError for PeekScheduleError { + fn code(&self) -> ErrorCodes { + match self { + PeekScheduleError::Internal(e) => e.code(), + PeekScheduleError::Conversion(_) => ErrorCodes::Internal, + } + } } #[derive(Error, Debug)] diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index 59f3a32cdc8..abf3dae6d2e 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -662,4 +662,11 @@ impl TestSysDb { inner.tenant_resource_names.insert(tenant_id, resource_name); Ok(UpdateTenantResponse {}) } + + pub(crate) async fn peek_schedule_by_collection_id( + &mut self, + _collection_ids: &[CollectionUuid], + ) -> Result<Vec<ScheduleEntry>, crate::sysdb::PeekScheduleError> { + Ok(vec![]) + } } diff --git a/rust/types/Cargo.toml b/rust/types/Cargo.toml index 3f9d50f25d3..00384f160f5 100644 --- a/rust/types/Cargo.toml +++ b/rust/types/Cargo.toml @@ -8,6 +8,7 @@ path = "src/lib.rs" [dependencies] async-trait = { workspace = true } +chrono = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } roaring = { workspace = true } diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index 0c5b3c953cd..0e7a4ba91d3 100644 --- a/rust/types/src/task.rs +++ b/rust/types/src/task.rs @@ -1,4 +1,6 @@ +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use std::str::FromStr; use std::time::SystemTime; use utoipa::ToSchema; use uuid::Uuid; @@ -85,3 +87,69 @@ pub struct Task { /// Timestamp when the task was last updated pub updated_at: SystemTime, } + +/// ScheduleEntry represents a scheduled task run for a collection. 
+#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ScheduleEntry { + pub collection_id: CollectionUuid, + pub task_id: Uuid, + pub task_run_nonce: Uuid, + pub when_to_run: Option<DateTime<Utc>>, +} + +impl TryFrom<crate::chroma_proto::ScheduleEntry> for ScheduleEntry { + type Error = ScheduleEntryConversionError; + + fn try_from(proto: crate::chroma_proto::ScheduleEntry) -> Result<Self, Self::Error> { + let collection_id = proto + .collection_id + .ok_or(ScheduleEntryConversionError::MissingField( + "collection_id".to_string(), + )) + .and_then(|id| { + CollectionUuid::from_str(&id).map_err(|_| { + ScheduleEntryConversionError::InvalidUuid("collection_id".to_string()) + }) + })?; + + let task_id = proto + .task_id + .ok_or(ScheduleEntryConversionError::MissingField( + "task_id".to_string(), + )) + .and_then(|id| { + Uuid::parse_str(&id) + .map_err(|_| ScheduleEntryConversionError::InvalidUuid("task_id".to_string())) + })?; + + let task_run_nonce = proto + .task_run_nonce + .ok_or(ScheduleEntryConversionError::MissingField( + "task_run_nonce".to_string(), + )) + .and_then(|nonce| { + Uuid::parse_str(&nonce).map_err(|_| { + ScheduleEntryConversionError::InvalidUuid("task_run_nonce".to_string()) + }) + })?; + + let when_to_run = proto + .when_to_run + .and_then(|ms| DateTime::from_timestamp_millis(ms as i64)); + + Ok(ScheduleEntry { + collection_id, + task_id, + task_run_nonce, + when_to_run, + }) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ScheduleEntryConversionError { + #[error("Missing required field: {0}")] + MissingField(String), + #[error("Invalid UUID for field: {0}")] + InvalidUuid(String), +}
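Taken together, the new SysDB RPCs form a peek/run/advance loop: PeekScheduleByCollectionId reports, per collection, which task to run next, when, and under which nonce; AdvanceTask then retires that nonce and installs the next one. A rough sketch of a downstream consumer, assuming the generated Go gRPC client for the SysDB service (the client name, field accessors, and the coordinatorpb import path follow protoc-gen-go conventions and are assumptions; the work itself is elided):

```go
package example

import (
	"context"
	"time"

	"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
)

// pollAndRun peeks the schedule for a set of collections, runs whatever is due,
// and advances each task with the nonce the schedule handed out for that run.
func pollAndRun(ctx context.Context, client coordinatorpb.SysDBClient, collectionIDs []string) error {
	resp, err := client.PeekScheduleByCollectionId(ctx, &coordinatorpb.PeekScheduleByCollectionIdRequest{
		CollectionId: collectionIDs,
	})
	if err != nil {
		return err
	}
	now := uint64(time.Now().UnixMilli())
	for _, entry := range resp.Schedule {
		if entry.WhenToRun == nil || *entry.WhenToRun > now {
			continue // nothing scheduled for this task, or not due yet
		}

		// ... run the task identified by entry.GetTaskId() ...

		// Advance with the nonce from the schedule so a concurrent runner cannot double-advance.
		if _, err := client.AdvanceTask(ctx, &coordinatorpb.AdvanceTaskRequest{
			CollectionId: entry.CollectionId,
			TaskId:       entry.TaskId,
			TaskRunNonce: entry.TaskRunNonce,
		}); err != nil {
			return err
		}
	}
	return nil
}
```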