Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 1e0faab

Browse files
authored
Adding support for environment variables set on execution (#344)
* added environment variables to TaskExecutionMetadata Signed-off-by: Daniel Rammer <daniel@union.ai> * added support for environment variables Signed-off-by: Daniel Rammer <daniel@union.ai> * implemented unit tests and fixed linter Signed-off-by: Daniel Rammer <daniel@union.ai> --------- Signed-off-by: Daniel Rammer <daniel@union.ai>
1 parent 63e1e45 commit 1e0faab

File tree

21 files changed

+113
-8
lines changed

21 files changed

+113
-8
lines changed

go/tasks/pluginmachinery/core/exec_metadata.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ type TaskExecutionMetadata interface {
4545
IsInterruptible() bool
4646
GetPlatformResources() *v1.ResourceRequirements
4747
GetInterruptibleFailureThreshold() uint32
48+
GetEnvironmentVariables() map[string]string
4849
}

go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go

Lines changed: 34 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/tasks/pluginmachinery/flytek8s/container_helper.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template.
296296
}
297297
container.Args = modifiedArgs
298298

299-
container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetTaskExecutionID())
299+
container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetEnvironmentVariables(), parameters.TaskExecMetadata.GetTaskExecutionID())
300300

301301
if parameters.TaskExecMetadata.GetOverrides() != nil && parameters.TaskExecMetadata.GetOverrides().GetResources() != nil {
302302
res := parameters.TaskExecMetadata.GetOverrides().GetResources()

go/tasks/pluginmachinery/flytek8s/container_helper_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ func TestToK8sContainer(t *testing.T) {
387387
mockTaskExecutionID.OnGetGeneratedName().Return("gen_name")
388388
mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID)
389389
mockTaskExecMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
390+
mockTaskExecMetadata.OnGetEnvironmentVariables().Return(map[string]string{
391+
"foo": "bar",
392+
})
390393

391394
tCtx := &mocks.TaskExecutionContext{}
392395
tCtx.OnTaskExecutionMetadata().Return(&mockTaskExecMetadata)
@@ -419,6 +422,10 @@ func TestToK8sContainer(t *testing.T) {
419422
Name: "k",
420423
Value: "v",
421424
},
425+
{
426+
Name: "foo",
427+
Value: "bar",
428+
},
422429
}, container.Env)
423430
errs := validation.IsDNS1123Label(container.Name)
424431
assert.Nil(t, errs)
@@ -454,6 +461,7 @@ func getTemplateParametersForTest(resourceRequirements, platformResources *v1.Re
454461
mockOverrides.OnGetResources().Return(resourceRequirements)
455462
mockTaskExecMetadata.OnGetOverrides().Return(&mockOverrides)
456463
mockTaskExecMetadata.OnGetPlatformResources().Return(platformResources)
464+
mockTaskExecMetadata.OnGetEnvironmentVariables().Return(nil)
457465

458466
mockInputReader := mocks2.InputReader{}
459467
mockInputPath := storage.DataReference("s3://input/path")

go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,13 @@ func GetExecutionEnvVars(id pluginsCore.TaskExecutionID) []v1.EnvVar {
115115
return envVars
116116
}
117117

118-
func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, id pluginsCore.TaskExecutionID) []v1.EnvVar {
118+
func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, taskEnvironmentVariables map[string]string, id pluginsCore.TaskExecutionID) []v1.EnvVar {
119119
envVars = append(envVars, GetContextEnvVars(ctx)...)
120120
envVars = append(envVars, GetExecutionEnvVars(id)...)
121121

122+
for k, v := range taskEnvironmentVariables {
123+
envVars = append(envVars, v1.EnvVar{Name: k, Value: v})
124+
}
122125
for k, v := range config.GetK8sPluginConfig().DefaultEnvVars {
123126
envVars = append(envVars, v1.EnvVar{Name: k, Value: v})
124127
}

go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,21 @@ func TestDecorateEnvVars(t *testing.T) {
266266
args args
267267
additionEnvVar map[string]string
268268
additionEnvVarFromEnv map[string]string
269+
executionEnvVar map[string]string
269270
want []v12.EnvVar
270271
}{
271-
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, expected},
272-
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, aggregated},
273-
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, aggregated},
272+
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, emptyEnvVar, expected},
273+
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, emptyEnvVar, aggregated},
274+
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, emptyEnvVar, aggregated},
275+
{"from-execution-metadata", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, additionalEnv, aggregated},
274276
}
275277
for _, tt := range tests {
276278
t.Run(tt.name, func(t *testing.T) {
277279
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
278280
DefaultEnvVars: tt.additionEnvVar,
279281
DefaultEnvVarsFromEnv: tt.additionEnvVarFromEnv,
280282
}))
281-
if got := DecorateEnvVars(ctx, tt.args.envVars, tt.args.id); !reflect.DeepEqual(got, tt.want) {
283+
if got := DecorateEnvVars(ctx, tt.args.envVars, tt.executionEnvVar, tt.args.id); !reflect.DeepEqual(got, tt.want) {
282284
t.Errorf("DecorateEnvVars() = %v, want %v", got, tt.want)
283285
}
284286
})

go/tasks/pluginmachinery/flytek8s/pod_helper_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore.
5757
taskExecutionMetadata.On("GetOverrides").Return(to)
5858
taskExecutionMetadata.On("IsInterruptible").Return(true)
5959
taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
60+
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
6061
return taskExecutionMetadata
6162
}
6263

go/tasks/plugins/array/awsbatch/transformer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInpu
127127

128128
func getEnvVarsForTask(ctx context.Context, execID pluginCore.TaskExecutionID, containerEnvVars []*core.KeyValuePair,
129129
defaultEnvVars map[string]string) []v1.EnvVar {
130-
envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), execID)
130+
envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), nil, execID)
131131
m := make(map[string]string, len(envVars))
132132
for _, envVar := range envVars {
133133
m[envVar.Name] = envVar.Value

go/tasks/plugins/array/k8s/management_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
107107
tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{})
108108
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
109109
tMeta.OnGetInterruptibleFailureThreshold().Return(2)
110+
tMeta.OnGetEnvironmentVariables().Return(nil)
110111

111112
ow := &mocks2.OutputWriter{}
112113
ow.OnGetOutputPrefixPath().Return("/prefix/")

go/tasks/plugins/k8s/dask/dask_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
164164
taskExecutionMetadata.OnGetPlatformResources().Return(&testPlatformResources)
165165
taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1))
166166
taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible)
167+
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
167168
overrides := &mocks.TaskOverrides{}
168169
overrides.OnGetResources().Return(resources)
169170
taskExecutionMetadata.OnGetOverrides().Return(overrides)

0 commit comments

Comments
 (0)