Skip to content

Commit dfb45ae

Browse files
authored
[Run Service] Add CreateRun default-input and response contract tests (#7058)
Signed-off-by: WangWang0226 <eeha8834@gmail.com>
1 parent 0e78c27 commit dfb45ae

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

runs/service/run_service_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,146 @@ func TestGetRunDetails_TaskSpecLookupFails(t *testing.T) {
270270
assert.Error(t, err)
271271
}
272272

273+
func TestFillDefaultInputsForCreateRun(t *testing.T) {
274+
inputs := &task.Inputs{
275+
Literals: []*task.NamedLiteral{
276+
{
277+
Name: "x",
278+
Value: &core.Literal{
279+
Value: &core.Literal_Scalar{
280+
Scalar: &core.Scalar{
281+
Value: &core.Scalar_Primitive{
282+
Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 7}},
283+
},
284+
},
285+
},
286+
},
287+
},
288+
},
289+
}
290+
291+
defaultInputs := []*task.NamedParameter{
292+
{
293+
Name: "x",
294+
Parameter: &core.Parameter{
295+
Behavior: &core.Parameter_Default{
296+
Default: &core.Literal{
297+
Value: &core.Literal_Scalar{
298+
Scalar: &core.Scalar{
299+
Value: &core.Scalar_Primitive{
300+
Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 42}},
301+
},
302+
},
303+
},
304+
},
305+
},
306+
},
307+
},
308+
{
309+
Name: "y",
310+
Parameter: &core.Parameter{
311+
Behavior: &core.Parameter_Default{
312+
Default: &core.Literal{
313+
Value: &core.Literal_Scalar{
314+
Scalar: &core.Scalar{
315+
Value: &core.Scalar_Primitive{
316+
Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "default"}},
317+
},
318+
},
319+
},
320+
},
321+
},
322+
},
323+
},
324+
}
325+
326+
gotInputs := fillDefaultInputs(inputs, defaultInputs)
327+
328+
assert.Len(t, gotInputs.Literals, 2)
329+
got := make(map[string]*core.Literal, len(gotInputs.Literals))
330+
for _, nl := range gotInputs.Literals {
331+
got[nl.Name] = nl.Value
332+
}
333+
assert.Equal(t, int64(7), got["x"].GetScalar().GetPrimitive().GetInteger(), "provided input should not be overwritten")
334+
assert.Equal(t, "default", got["y"].GetScalar().GetPrimitive().GetStringValue(), "missing input should be filled from default")
335+
}
336+
337+
func TestCreateRunResponseIncludesMetadataAndStatus(t *testing.T) {
338+
actionRepo := &repoMocks.ActionRepo{}
339+
taskRepo := &repoMocks.TaskRepo{}
340+
actionsClient := &mockActionsClient{}
341+
repo := &repoMocks.Repository{}
342+
store := &storageMocks.ComposedProtobufStore{}
343+
dataStore := &storage.DataStore{ComposedProtobufStore: store}
344+
345+
repo.On("ActionRepo").Return(actionRepo)
346+
repo.On("TaskRepo").Maybe().Return(taskRepo)
347+
348+
svc := &RunService{
349+
repo: repo,
350+
actionsClient: actionsClient,
351+
storagePrefix: "s3://flyte-data",
352+
dataStore: dataStore,
353+
}
354+
355+
runID := &common.RunIdentifier{
356+
Org: "test-org",
357+
Project: "test-project",
358+
Domain: "test-domain",
359+
Name: "rtest12345",
360+
}
361+
createdAt := time.Now().UTC().Truncate(time.Second)
362+
363+
store.On("WriteProtobuf", mock.Anything, mock.Anything, storage.Options{}, mock.Anything).Return(nil).Once()
364+
365+
actionRepo.On("CreateRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
366+
Return(&models.Run{
367+
Org: runID.Org,
368+
Project: runID.Project,
369+
Domain: runID.Domain,
370+
Name: runID.Name,
371+
Phase: int32(common.ActionPhase_ACTION_PHASE_QUEUED),
372+
CreatedAt: createdAt,
373+
Attempts: 1,
374+
CacheStatus: core.CatalogCacheStatus_CACHE_DISABLED,
375+
}, nil).Once()
376+
377+
actionsClient.On("Enqueue", mock.Anything, mock.Anything).
378+
Return(connect.NewResponse(&actions.EnqueueResponse{}), nil).Once()
379+
380+
resp, err := svc.CreateRun(context.Background(), connect.NewRequest(&workflow.CreateRunRequest{
381+
Id: &workflow.CreateRunRequest_RunId{
382+
RunId: runID,
383+
},
384+
Task: &workflow.CreateRunRequest_TaskSpec{
385+
TaskSpec: &task.TaskSpec{},
386+
},
387+
}))
388+
assert.NoError(t, err)
389+
assert.NotNil(t, resp)
390+
assert.NotNil(t, resp.Msg.GetRun())
391+
assert.NotNil(t, resp.Msg.GetRun().GetAction())
392+
assert.NotNil(t, resp.Msg.GetRun().GetAction().GetId())
393+
assert.Equal(t, runID.Name, resp.Msg.GetRun().GetAction().GetId().GetName())
394+
assert.NotNil(t, resp.Msg.GetRun().GetAction().GetMetadata())
395+
396+
status := resp.Msg.GetRun().GetAction().GetStatus()
397+
assert.NotNil(t, status)
398+
assert.Equal(t, common.ActionPhase_ACTION_PHASE_QUEUED, status.GetPhase())
399+
assert.NotNil(t, status.GetStartTime())
400+
assert.True(t, status.GetStartTime().AsTime().Equal(createdAt))
401+
assert.Equal(t, uint32(1), status.GetAttempts())
402+
assert.Equal(t, core.CatalogCacheStatus_CACHE_DISABLED, status.GetCacheStatus())
403+
assert.Nil(t, status.EndTime)
404+
assert.Nil(t, status.DurationMs)
405+
406+
repo.AssertExpectations(t)
407+
actionRepo.AssertExpectations(t)
408+
taskRepo.AssertExpectations(t)
409+
actionsClient.AssertExpectations(t)
410+
store.AssertExpectations(t)
411+
}
412+
273413
func TestAbortRun(t *testing.T) {
274414
runID := &common.RunIdentifier{
275415
Org: "test-org",

0 commit comments

Comments
 (0)