diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 4da44baee1..fd3049c186 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -270,6 +270,146 @@ func TestGetRunDetails_TaskSpecLookupFails(t *testing.T) { assert.Error(t, err) } +func TestFillDefaultInputsForCreateRun(t *testing.T) { + inputs := &task.Inputs{ + Literals: []*task.NamedLiteral{ + { + Name: "x", + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 7}}, + }, + }, + }, + }, + }, + }, + } + + defaultInputs := []*task.NamedParameter{ + { + Name: "x", + Parameter: &core.Parameter{ + Behavior: &core.Parameter_Default{ + Default: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 42}}, + }, + }, + }, + }, + }, + }, + }, + { + Name: "y", + Parameter: &core.Parameter{ + Behavior: &core.Parameter_Default{ + Default: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "default"}}, + }, + }, + }, + }, + }, + }, + }, + } + + gotInputs := fillDefaultInputs(inputs, defaultInputs) + + assert.Len(t, gotInputs.Literals, 2) + got := make(map[string]*core.Literal, len(gotInputs.Literals)) + for _, nl := range gotInputs.Literals { + got[nl.Name] = nl.Value + } + assert.Equal(t, int64(7), got["x"].GetScalar().GetPrimitive().GetInteger(), "provided input should not be overwritten") + assert.Equal(t, "default", got["y"].GetScalar().GetPrimitive().GetStringValue(), "missing input should be filled from default") +} + +func TestCreateRunResponseIncludesMetadataAndStatus(t *testing.T) { + actionRepo := &repoMocks.ActionRepo{} + taskRepo := &repoMocks.TaskRepo{} + actionsClient := &mockActionsClient{} + repo := &repoMocks.Repository{} + store := &storageMocks.ComposedProtobufStore{} + dataStore := &storage.DataStore{ComposedProtobufStore: store} + + repo.On("ActionRepo").Return(actionRepo) + repo.On("TaskRepo").Maybe().Return(taskRepo) + + svc := &RunService{ + repo: repo, + actionsClient: actionsClient, + storagePrefix: "s3://flyte-data", + dataStore: dataStore, + } + + runID := &common.RunIdentifier{ + Org: "test-org", + Project: "test-project", + Domain: "test-domain", + Name: "rtest12345", + } + createdAt := time.Now().UTC().Truncate(time.Second) + + store.On("WriteProtobuf", mock.Anything, mock.Anything, storage.Options{}, mock.Anything).Return(nil).Once() + + actionRepo.On("CreateRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(&models.Run{ + Org: runID.Org, + Project: runID.Project, + Domain: runID.Domain, + Name: runID.Name, + Phase: int32(common.ActionPhase_ACTION_PHASE_QUEUED), + CreatedAt: createdAt, + Attempts: 1, + CacheStatus: core.CatalogCacheStatus_CACHE_DISABLED, + }, nil).Once() + + actionsClient.On("Enqueue", mock.Anything, mock.Anything). + Return(connect.NewResponse(&actions.EnqueueResponse{}), nil).Once() + + resp, err := svc.CreateRun(context.Background(), connect.NewRequest(&workflow.CreateRunRequest{ + Id: &workflow.CreateRunRequest_RunId{ + RunId: runID, + }, + Task: &workflow.CreateRunRequest_TaskSpec{ + TaskSpec: &task.TaskSpec{}, + }, + })) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.Msg.GetRun()) + assert.NotNil(t, resp.Msg.GetRun().GetAction()) + assert.NotNil(t, resp.Msg.GetRun().GetAction().GetId()) + assert.Equal(t, runID.Name, resp.Msg.GetRun().GetAction().GetId().GetName()) + assert.NotNil(t, resp.Msg.GetRun().GetAction().GetMetadata()) + + status := resp.Msg.GetRun().GetAction().GetStatus() + assert.NotNil(t, status) + assert.Equal(t, common.ActionPhase_ACTION_PHASE_QUEUED, status.GetPhase()) + assert.NotNil(t, status.GetStartTime()) + assert.True(t, status.GetStartTime().AsTime().Equal(createdAt)) + assert.Equal(t, uint32(1), status.GetAttempts()) + assert.Equal(t, core.CatalogCacheStatus_CACHE_DISABLED, status.GetCacheStatus()) + assert.Nil(t, status.EndTime) + assert.Nil(t, status.DurationMs) + + repo.AssertExpectations(t) + actionRepo.AssertExpectations(t) + taskRepo.AssertExpectations(t) + actionsClient.AssertExpectations(t) + store.AssertExpectations(t) +} + func TestAbortRun(t *testing.T) { runID := &common.RunIdentifier{ Org: "test-org",