Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 63 additions & 0 deletions go/pkg/sysdb/coordinator/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,39 @@ func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteT
}, nil
}

// AdvanceTask marks a task run as complete and sets the nonce for the next task run.
//
// The request must carry both task_id and task_run_nonce, and each must parse as a
// UUID; otherwise an InvalidArgument status error is returned. An error from the
// task store is logged and handed back to the caller unchanged.
func (s *Coordinator) AdvanceTask(ctx context.Context, req *coordinatorpb.AdvanceTaskRequest) (*coordinatorpb.AdvanceTaskResponse, error) {
	// Presence is checked for both fields before either one is parsed, so a
	// missing field is always reported ahead of a malformed one.
	switch {
	case req.TaskId == nil:
		log.Error("AdvanceTask: task_id is required")
		return nil, status.Errorf(codes.InvalidArgument, "task_id is required")
	case req.TaskRunNonce == nil:
		log.Error("AdvanceTask: task_run_nonce is required")
		return nil, status.Errorf(codes.InvalidArgument, "task_run_nonce is required")
	}

	id, parseErr := uuid.Parse(*req.TaskId)
	if parseErr != nil {
		log.Error("AdvanceTask: invalid task_id", zap.Error(parseErr))
		return nil, status.Errorf(codes.InvalidArgument, "invalid task_id: %v", parseErr)
	}

	nonce, parseErr := uuid.Parse(*req.TaskRunNonce)
	if parseErr != nil {
		log.Error("AdvanceTask: invalid task_run_nonce", zap.Error(parseErr))
		return nil, status.Errorf(codes.InvalidArgument, "invalid task_run_nonce: %v", parseErr)
	}

	if advanceErr := s.catalog.metaDomain.TaskDb(ctx).AdvanceTask(id, nonce); advanceErr != nil {
		log.Error("AdvanceTask failed", zap.Error(advanceErr), zap.String("task_id", id.String()))
		return nil, advanceErr
	}

	return &coordinatorpb.AdvanceTaskResponse{}, nil
}

// GetOperators retrieves all operators from the database
func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll()
Expand All @@ -288,3 +321,33 @@ func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOp
Operators: protoOperators,
}, nil
}

// PeekScheduleByCollectionId returns, for the collection IDs in the request, one
// schedule entry per task found for those collections. Each entry carries the input
// collection ID, the task ID, the nonce to use for the next run and, when the task
// has a next-run time, that time in Unix milliseconds.
//
// An error from the task store is logged and returned unchanged.
func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) {
	rows, err := s.catalog.metaDomain.TaskDb(ctx).PeekScheduleByCollectionId(req.CollectionId)
	if err != nil {
		log.Error("PeekScheduleByCollectionId failed", zap.Error(err))
		return nil, err
	}

	schedule := make([]*coordinatorpb.ScheduleEntry, 0, len(rows))
	for _, row := range rows {
		// WhenToRun stays nil for tasks that have no next-run time recorded.
		var when *uint64
		if row.NextRun != nil {
			millis := uint64(row.NextRun.UnixMilli())
			when = &millis
		}

		schedule = append(schedule, &coordinatorpb.ScheduleEntry{
			CollectionId: &row.InputCollectionID,
			TaskId:       proto.String(row.ID.String()),
			TaskRunNonce: proto.String(row.NextNonce.String()),
			WhenToRun:    when,
		})
	}

	return &coordinatorpb.PeekScheduleByCollectionIdResponse{
		Schedule: schedule,
	}, nil
}
24 changes: 24 additions & 0 deletions go/pkg/sysdb/grpc/task_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRe
return res, nil
}

// AdvanceTask is the gRPC entry point for advancing a task. It logs the collection
// and task IDs from the request, delegates to the coordinator, and logs and returns
// any error the coordinator reports.
func (s *Server) AdvanceTask(ctx context.Context, req *coordinatorpb.AdvanceTaskRequest) (*coordinatorpb.AdvanceTaskResponse, error) {
	log.Info("AdvanceTask", zap.String("collection_id", req.GetCollectionId()), zap.String("task_id", req.GetTaskId()))

	resp, err := s.coordinator.AdvanceTask(ctx, req)
	if err == nil {
		return resp, nil
	}

	log.Error("AdvanceTask failed", zap.Error(err))
	return nil, err
}

func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
log.Info("GetOperators")

Expand All @@ -63,3 +75,15 @@ func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperato

return res, nil
}

// PeekScheduleByCollectionId is the gRPC entry point for peeking at task schedules.
// It logs how many collection IDs were requested, delegates to the coordinator, and
// logs and returns any error the coordinator reports.
func (s *Server) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) {
	requested := int64(len(req.CollectionId))
	log.Info("PeekScheduleByCollectionId", zap.Int64("num_collections", requested))

	resp, err := s.coordinator.PeekScheduleByCollectionId(ctx, req)
	if err == nil {
		return resp, nil
	}

	log.Error("PeekScheduleByCollectionId failed", zap.Error(err))
	return nil, err
}
65 changes: 65 additions & 0 deletions go/pkg/sysdb/metastore/db/dao/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package dao

import (
"errors"
"time"

"github.com/chroma-core/chroma/go/pkg/common"
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/pingcap/log"
"go.uber.org/zap"
Expand Down Expand Up @@ -58,6 +60,55 @@ func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.
return &task, nil
}

// GetByID loads the task whose task_id matches taskID and whose is_deleted flag is
// false.
//
// It returns (nil, nil) when no such row exists. Any other query error is logged
// and returned.
func (s *taskDb) GetByID(taskID uuid.UUID) (*dbmodel.Task, error) {
	var found dbmodel.Task
	query := s.db.
		Where("task_id = ?", taskID).
		Where("is_deleted = ?", false)

	switch err := query.First(&found).Error; {
	case err == nil:
		return &found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is reported as "no task", not as an error.
		return nil, nil
	default:
		log.Error("GetByID failed", zap.Error(err), zap.String("task_id", taskID.String()))
		return nil, err
	}
}

// AdvanceTask replaces the task's next_nonce with a freshly generated UUIDv7, sets
// last_run to the current time and resets current_attempts to 0. The update only
// applies to the row whose task_id equals taskID, whose next_nonce equals
// taskRunNonce and which is not marked deleted.
//
// It returns common.ErrTaskNotFound when the UPDATE affects no rows, and the
// underlying error when nonce generation or the UPDATE itself fails.
func (s *taskDb) AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I'm going to remove/modify this later. The path that updates next_nonce also needs to transactionally update completion_offset.

nextNonce, err := uuid.NewV7()
if err != nil {
log.Error("AdvanceTask: failed to generate next nonce", zap.Error(err))
return err
}

// One timestamp is bound to both time placeholders below, so the value compared
// for updated_at and the value written to last_run are the same instant.
now := time.Now()
// GREATEST keeps updated_at from being set to a value smaller than its current one.
// Matching on next_nonce makes the update a no-op when the caller's nonce is not
// the one currently stored for the task.
result := s.db.Exec(`
UPDATE tasks
SET next_nonce = ?,
updated_at = GREATEST(updated_at, GREATEST(?, last_run)),
last_run = ?,
current_attempts = 0
WHERE task_id = ?
AND next_nonce = ?
AND is_deleted = false
`, nextNonce, now, now, taskID, taskRunNonce)

if result.Error != nil {
log.Error("AdvanceTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String()))
return result.Error
}

// No row matched the WHERE clause: the task ID is unknown, the row is marked
// deleted, or the supplied nonce differs from the stored next_nonce. All three
// cases are reported with the same sentinel error.
if result.RowsAffected == 0 {
log.Warn("AdvanceTask: no rows affected", zap.String("task_id", taskID.String()), zap.String("task_run_nonce", taskRunNonce.String()))
return common.ErrTaskNotFound
}

return nil
}

func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error {
// Update task name and is_deleted in a single query
// Format: _deleted_<original_name>_<input_collection_id>_<task_id>
Expand All @@ -82,3 +133,17 @@ func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error {

return nil
}

// PeekScheduleByCollectionId returns every task whose input_collection_id is one of
// collectionIDs and whose is_deleted flag is false. A query error is logged and
// returned.
func (s *taskDb) PeekScheduleByCollectionId(collectionIDs []string) ([]*dbmodel.Task, error) {
	var found []*dbmodel.Task

	query := s.db.
		Where("input_collection_id IN ?", collectionIDs).
		Where("is_deleted = ?", false)
	if err := query.Find(&found).Error; err != nil {
		log.Error("PeekScheduleByCollectionId failed", zap.Error(err))
		return nil, err
	}

	return found, nil
}
138 changes: 138 additions & 0 deletions go/pkg/sysdb/metastore/db/dao/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,144 @@ func (suite *TaskDbTestSuite) TestTaskDb_DeleteAll() {
}
}

// TestTaskDb_GetByID inserts a task, reads it back through GetByID and checks that
// the ID, name and operator ID match what was inserted, then deletes the row.
func (suite *TaskDbTestSuite) TestTaskDb_GetByID() {
taskID := uuid.New()
operatorID := dbmodel.OperatorRecordCounter
// The error returned by uuid.NewV7 is discarded here.
nextNonce, _ := uuid.NewV7()

task := &dbmodel.Task{
ID: taskID,
Name: "test-get-by-id-task",
OperatorID: operatorID,
InputCollectionID: "input_col_id",
OutputCollectionName: "output_col_name",
OperatorParams: "{}",
TenantID: "tenant1",
DatabaseID: "db1",
MinRecordsForTask: 100,
NextNonce: nextNonce,
}

err := suite.Db.Insert(task)
suite.Require().NoError(err)

retrieved, err := suite.Db.GetByID(taskID)
suite.Require().NoError(err)
suite.Require().NotNil(retrieved)
suite.Require().Equal(task.ID, retrieved.ID)
suite.Require().Equal(task.Name, retrieved.Name)
suite.Require().Equal(task.OperatorID, retrieved.OperatorID)

// Cleanup: remove the inserted row by task_id. This is the last statement of the
// test, so it only runs when control reaches the end of the function.
suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
Comment on lines +315 to +343
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Resource leak: The test creates and inserts a task but uses Unscoped().Delete() for cleanup, which bypasses GORM hooks and constraints. If the test fails before reaching cleanup, the task will remain in the database permanently. Use proper cleanup with defer or t.Cleanup():

func (suite *TaskDbTestSuite) TestTaskDb_GetByID() {
    // ... setup code ...
    
    // Clean up immediately after insert, regardless of test outcome
    suite.T().Cleanup(func() {
        suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
    })
    
    err := suite.Db.Insert(task)
    suite.Require().NoError(err)
    // ... rest of test ...
}
Context for Agents
[**BestPractice**]

Resource leak: The test creates and inserts a task but uses `Unscoped().Delete()` for cleanup, which bypasses GORM hooks and constraints. If the test fails before reaching cleanup, the task will remain in the database permanently. Use proper cleanup with `defer` or `t.Cleanup()`:

```go
func (suite *TaskDbTestSuite) TestTaskDb_GetByID() {
    // ... setup code ...
    
    // Clean up immediately after insert, regardless of test outcome
    suite.T().Cleanup(func() {
        suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
    })
    
    err := suite.Db.Insert(task)
    suite.Require().NoError(err)
    // ... rest of test ...
}
```

File: go/pkg/sysdb/metastore/db/dao/task_test.go
Line: 343

}

// TestTaskDb_GetByID_NotFound checks that looking up a random, never-inserted task
// ID yields neither an error nor a task.
func (suite *TaskDbTestSuite) TestTaskDb_GetByID_NotFound() {
	missingID := uuid.New()

	task, lookupErr := suite.Db.GetByID(missingID)

	suite.Require().NoError(lookupErr)
	suite.Require().Nil(task)
}

// TestTaskDb_GetByID_IgnoresDeleted inserts a task, soft-deletes it by collection
// and name, and checks that GetByID then reports no task and no error.
func (suite *TaskDbTestSuite) TestTaskDb_GetByID_IgnoresDeleted() {
	taskID := uuid.New()
	operatorID := dbmodel.OperatorRecordCounter
	// uuid.Must surfaces a nonce-generation failure instead of silently using a
	// zero UUID.
	nextNonce := uuid.Must(uuid.NewV7())

	task := &dbmodel.Task{
		ID:                   taskID,
		Name:                 "test-get-by-id-deleted",
		OperatorID:           operatorID,
		InputCollectionID:    "input1",
		OutputCollectionName: "output1",
		OperatorParams:       "{}",
		TenantID:             "tenant1",
		DatabaseID:           "db1",
		MinRecordsForTask:    100,
		NextNonce:            nextNonce,
	}

	err := suite.Db.Insert(task)
	suite.Require().NoError(err)
	// Deferred so the row is removed even when a Require() assertion below aborts
	// the test. Cleanup is best-effort; its result is intentionally not asserted.
	defer func() {
		suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
	}()

	err = suite.Db.SoftDelete("input1", "test-get-by-id-deleted")
	suite.Require().NoError(err)

	retrieved, err := suite.Db.GetByID(taskID)
	suite.Require().NoError(err)
	suite.Require().Nil(retrieved)
}

// TestTaskDb_AdvanceTask inserts a task with a known nonce and a non-zero attempt
// count, advances it with that nonce, and checks that the stored nonce changed,
// last_run is set and current_attempts is back to 0.
func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask() {
	taskID := uuid.New()
	operatorID := dbmodel.OperatorRecordCounter
	// uuid.Must surfaces a nonce-generation failure instead of silently using a
	// zero UUID.
	originalNonce := uuid.Must(uuid.NewV7())

	task := &dbmodel.Task{
		ID:                   taskID,
		Name:                 "test-advance-task",
		OperatorID:           operatorID,
		InputCollectionID:    "input_col_id",
		OutputCollectionName: "output_col_name",
		OperatorParams:       "{}",
		TenantID:             "tenant1",
		DatabaseID:           "db1",
		MinRecordsForTask:    100,
		NextNonce:            originalNonce,
		CurrentAttempts:      3,
	}

	err := suite.Db.Insert(task)
	suite.Require().NoError(err)
	// Deferred so the row is removed even when a Require() assertion below aborts
	// the test. Cleanup is best-effort; its result is intentionally not asserted.
	defer func() {
		suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
	}()

	err = suite.Db.AdvanceTask(taskID, originalNonce)
	suite.Require().NoError(err)

	retrieved, err := suite.Db.GetByID(taskID)
	suite.Require().NoError(err)
	suite.Require().NotNil(retrieved)
	suite.Require().NotEqual(originalNonce, retrieved.NextNonce)
	suite.Require().NotNil(retrieved.LastRun)
	suite.Require().Equal(int32(0), retrieved.CurrentAttempts)
}

// TestTaskDb_AdvanceTask_InvalidNonce inserts a task with one nonce and checks that
// advancing it with a different nonce fails with common.ErrTaskNotFound.
func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_InvalidNonce() {
	taskID := uuid.New()
	operatorID := dbmodel.OperatorRecordCounter
	// uuid.Must surfaces a nonce-generation failure; a silently zeroed pair of
	// nonces would be equal and make this test meaningless.
	correctNonce := uuid.Must(uuid.NewV7())
	wrongNonce := uuid.Must(uuid.NewV7())

	task := &dbmodel.Task{
		ID:                   taskID,
		Name:                 "test-advance-task-wrong-nonce",
		OperatorID:           operatorID,
		InputCollectionID:    "input_col_id",
		OutputCollectionName: "output_col_name",
		OperatorParams:       "{}",
		TenantID:             "tenant1",
		DatabaseID:           "db1",
		MinRecordsForTask:    100,
		NextNonce:            correctNonce,
	}

	err := suite.Db.Insert(task)
	suite.Require().NoError(err)
	// Deferred so the row is removed even when a Require() assertion below aborts
	// the test. Cleanup is best-effort; its result is intentionally not asserted.
	defer func() {
		suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
	}()

	err = suite.Db.AdvanceTask(taskID, wrongNonce)
	suite.Require().Error(err)
	suite.Require().Equal(common.ErrTaskNotFound, err)
}

// TestTaskDb_AdvanceTask_NotFound checks that advancing a random, never-inserted
// task ID fails with common.ErrTaskNotFound.
func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_NotFound() {
	unknownTaskID := uuid.New()
	unknownNonce := uuid.Must(uuid.NewV7())

	advanceErr := suite.Db.AdvanceTask(unknownTaskID, unknownNonce)

	suite.Require().Error(advanceErr)
	suite.Require().Equal(common.ErrTaskNotFound, advanceErr)
}

// TestOperatorConstantsMatchSeededDatabase verifies that operator constants in
// dbmodel/constants.go match what we seed in the test database (which should match migrations).
// This catches drift between constants and migrations at test time.
Expand Down
3 changes: 3 additions & 0 deletions go/pkg/sysdb/metastore/db/dbmodel/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func (v Task) TableName() string {
// ITaskDb is the data-access interface for task rows.
type ITaskDb interface {
// Insert stores the given task.
// NOTE(review): conflict/duplicate behavior is not visible from here — confirm against the implementation.
Insert(task *Task) error
// GetByName looks up a task by its input collection ID and task name.
// NOTE(review): not-found behavior is not visible from here — confirm against the implementation.
GetByName(inputCollectionID string, taskName string) (*Task, error)
// GetByID looks up a task by its ID. The taskDb implementation only considers
// rows that are not marked deleted and returns (nil, nil) when none matches.
GetByID(taskID uuid.UUID) (*Task, error)
// AdvanceTask completes the run identified by taskRunNonce for the given task and
// installs a new nonce for the next run. The taskDb implementation returns
// common.ErrTaskNotFound when no non-deleted row matches both the task ID and nonce.
AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error
// SoftDelete marks the task identified by input collection ID and name as deleted.
SoftDelete(inputCollectionID string, taskName string) error
// DeleteAll presumably removes every task row — TODO confirm against the implementation.
DeleteAll() error
// PeekScheduleByCollectionId returns the tasks whose input collection is one of
// collectionIDs. The taskDb implementation excludes rows marked deleted.
PeekScheduleByCollectionId(collectionIDs []string) ([]*Task, error)
}
Loading
Loading