@@ -270,6 +270,146 @@ func TestGetRunDetails_TaskSpecLookupFails(t *testing.T) {
270270 assert .Error (t , err )
271271}
272272
273+ func TestFillDefaultInputsForCreateRun (t * testing.T ) {
274+ inputs := & task.Inputs {
275+ Literals : []* task.NamedLiteral {
276+ {
277+ Name : "x" ,
278+ Value : & core.Literal {
279+ Value : & core.Literal_Scalar {
280+ Scalar : & core.Scalar {
281+ Value : & core.Scalar_Primitive {
282+ Primitive : & core.Primitive {Value : & core.Primitive_Integer {Integer : 7 }},
283+ },
284+ },
285+ },
286+ },
287+ },
288+ },
289+ }
290+
291+ defaultInputs := []* task.NamedParameter {
292+ {
293+ Name : "x" ,
294+ Parameter : & core.Parameter {
295+ Behavior : & core.Parameter_Default {
296+ Default : & core.Literal {
297+ Value : & core.Literal_Scalar {
298+ Scalar : & core.Scalar {
299+ Value : & core.Scalar_Primitive {
300+ Primitive : & core.Primitive {Value : & core.Primitive_Integer {Integer : 42 }},
301+ },
302+ },
303+ },
304+ },
305+ },
306+ },
307+ },
308+ {
309+ Name : "y" ,
310+ Parameter : & core.Parameter {
311+ Behavior : & core.Parameter_Default {
312+ Default : & core.Literal {
313+ Value : & core.Literal_Scalar {
314+ Scalar : & core.Scalar {
315+ Value : & core.Scalar_Primitive {
316+ Primitive : & core.Primitive {Value : & core.Primitive_StringValue {StringValue : "default" }},
317+ },
318+ },
319+ },
320+ },
321+ },
322+ },
323+ },
324+ }
325+
326+ gotInputs := fillDefaultInputs (inputs , defaultInputs )
327+
328+ assert .Len (t , gotInputs .Literals , 2 )
329+ got := make (map [string ]* core.Literal , len (gotInputs .Literals ))
330+ for _ , nl := range gotInputs .Literals {
331+ got [nl .Name ] = nl .Value
332+ }
333+ assert .Equal (t , int64 (7 ), got ["x" ].GetScalar ().GetPrimitive ().GetInteger (), "provided input should not be overwritten" )
334+ assert .Equal (t , "default" , got ["y" ].GetScalar ().GetPrimitive ().GetStringValue (), "missing input should be filled from default" )
335+ }
336+
337+ func TestCreateRunResponseIncludesMetadataAndStatus (t * testing.T ) {
338+ actionRepo := & repoMocks.ActionRepo {}
339+ taskRepo := & repoMocks.TaskRepo {}
340+ actionsClient := & mockActionsClient {}
341+ repo := & repoMocks.Repository {}
342+ store := & storageMocks.ComposedProtobufStore {}
343+ dataStore := & storage.DataStore {ComposedProtobufStore : store }
344+
345+ repo .On ("ActionRepo" ).Return (actionRepo )
346+ repo .On ("TaskRepo" ).Maybe ().Return (taskRepo )
347+
348+ svc := & RunService {
349+ repo : repo ,
350+ actionsClient : actionsClient ,
351+ storagePrefix : "s3://flyte-data" ,
352+ dataStore : dataStore ,
353+ }
354+
355+ runID := & common.RunIdentifier {
356+ Org : "test-org" ,
357+ Project : "test-project" ,
358+ Domain : "test-domain" ,
359+ Name : "rtest12345" ,
360+ }
361+ createdAt := time .Now ().UTC ().Truncate (time .Second )
362+
363+ store .On ("WriteProtobuf" , mock .Anything , mock .Anything , storage.Options {}, mock .Anything ).Return (nil ).Once ()
364+
365+ actionRepo .On ("CreateRun" , mock .Anything , mock .Anything , mock .Anything , mock .Anything ).
366+ Return (& models.Run {
367+ Org : runID .Org ,
368+ Project : runID .Project ,
369+ Domain : runID .Domain ,
370+ Name : runID .Name ,
371+ Phase : int32 (common .ActionPhase_ACTION_PHASE_QUEUED ),
372+ CreatedAt : createdAt ,
373+ Attempts : 1 ,
374+ CacheStatus : core .CatalogCacheStatus_CACHE_DISABLED ,
375+ }, nil ).Once ()
376+
377+ actionsClient .On ("Enqueue" , mock .Anything , mock .Anything ).
378+ Return (connect .NewResponse (& actions.EnqueueResponse {}), nil ).Once ()
379+
380+ resp , err := svc .CreateRun (context .Background (), connect .NewRequest (& workflow.CreateRunRequest {
381+ Id : & workflow.CreateRunRequest_RunId {
382+ RunId : runID ,
383+ },
384+ Task : & workflow.CreateRunRequest_TaskSpec {
385+ TaskSpec : & task.TaskSpec {},
386+ },
387+ }))
388+ assert .NoError (t , err )
389+ assert .NotNil (t , resp )
390+ assert .NotNil (t , resp .Msg .GetRun ())
391+ assert .NotNil (t , resp .Msg .GetRun ().GetAction ())
392+ assert .NotNil (t , resp .Msg .GetRun ().GetAction ().GetId ())
393+ assert .Equal (t , runID .Name , resp .Msg .GetRun ().GetAction ().GetId ().GetName ())
394+ assert .NotNil (t , resp .Msg .GetRun ().GetAction ().GetMetadata ())
395+
396+ status := resp .Msg .GetRun ().GetAction ().GetStatus ()
397+ assert .NotNil (t , status )
398+ assert .Equal (t , common .ActionPhase_ACTION_PHASE_QUEUED , status .GetPhase ())
399+ assert .NotNil (t , status .GetStartTime ())
400+ assert .True (t , status .GetStartTime ().AsTime ().Equal (createdAt ))
401+ assert .Equal (t , uint32 (1 ), status .GetAttempts ())
402+ assert .Equal (t , core .CatalogCacheStatus_CACHE_DISABLED , status .GetCacheStatus ())
403+ assert .Nil (t , status .EndTime )
404+ assert .Nil (t , status .DurationMs )
405+
406+ repo .AssertExpectations (t )
407+ actionRepo .AssertExpectations (t )
408+ taskRepo .AssertExpectations (t )
409+ actionsClient .AssertExpectations (t )
410+ store .AssertExpectations (t )
411+ }
412+
273413func TestAbortRun (t * testing.T ) {
274414 runID := & common.RunIdentifier {
275415 Org : "test-org" ,
0 commit comments