Skip to content

Commit 4194277

Browse files
Merge pull request #191 from mprahl/cherry-pick-platform-spec
Cherry-pick: Add support for platform specs on K8s native API (kubeflow#12016)
2 parents 56388b6 + a94537f commit 4194277

File tree

9 files changed

+902
-64
lines changed

9 files changed

+902
-64
lines changed

backend/src/apiserver/storage/pipeline_store_kubernetes_test.go

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,10 @@ func TestCreateK8sPipelineVersion(t *testing.T) {
279279
store := NewPipelineStoreKubernetes(getClient())
280280

281281
pipelineVersion := &model.PipelineVersion{
282-
Name: "Test Pipeline Version",
283-
PipelineId: DefaultFakePipelineIdTwo,
284-
Description: "Test Pipeline Version Description",
282+
Name: "Test Pipeline Version",
283+
PipelineId: DefaultFakePipelineIdTwo,
284+
Description: "Test Pipeline Version Description",
285+
PipelineSpec: getBasicPipelineSpecYAML(),
285286
}
286287

287288
_, err := store.CreatePipelineVersion(pipelineVersion)
@@ -355,13 +356,15 @@ func TestListK8sPipelineVersions_Pagination(t *testing.T) {
355356
store := NewPipelineStoreKubernetes(getClient())
356357

357358
pipelineVersion1 := &model.PipelineVersion{
358-
Name: "Test Pipeline Version 1",
359-
PipelineId: DefaultFakePipelineIdTwo,
359+
Name: "Test Pipeline Version 1",
360+
PipelineId: DefaultFakePipelineIdTwo,
361+
PipelineSpec: getBasicPipelineSpecYAML(),
360362
}
361363

362364
pipelineVersion2 := &model.PipelineVersion{
363-
Name: "Test Pipeline Version 2",
364-
PipelineId: DefaultFakePipelineIdTwo,
365+
Name: "Test Pipeline Version 2",
366+
PipelineId: DefaultFakePipelineIdTwo,
367+
PipelineSpec: getBasicPipelineSpecYAML(),
365368
}
366369

367370
_, err := store.CreatePipelineVersion(pipelineVersion1)
@@ -393,13 +396,15 @@ func TestListK8sPipelineVersions_Pagination_Descend(t *testing.T) {
393396
store := NewPipelineStoreKubernetes(getClient())
394397

395398
pipelineVersion1 := &model.PipelineVersion{
396-
Name: "Test Pipeline Version 1",
397-
PipelineId: DefaultFakePipelineIdTwo,
399+
Name: "Test Pipeline Version 1",
400+
PipelineId: DefaultFakePipelineIdTwo,
401+
PipelineSpec: getBasicPipelineSpecYAML(),
398402
}
399403

400404
pipelineVersion2 := &model.PipelineVersion{
401-
Name: "Test Pipeline Version 2",
402-
PipelineId: DefaultFakePipelineIdTwo,
405+
Name: "Test Pipeline Version 2",
406+
PipelineId: DefaultFakePipelineIdTwo,
407+
PipelineSpec: getBasicPipelineSpecYAML(),
403408
}
404409

405410
_, err := store.CreatePipelineVersion(pipelineVersion1)
@@ -480,13 +485,45 @@ func TestCreatePipelineAndPipelineVersion(t *testing.T) {
480485
Name: "Test Pipeline",
481486
}
482487
k8sPipelineVersion := &model.PipelineVersion{
483-
Name: "Test Pipeline Version",
488+
Name: "Test Pipeline Version",
489+
PipelineSpec: getBasicPipelineSpecYAML(),
484490
}
485491

486492
_, _, err := store.CreatePipelineAndPipelineVersion(k8sPipeline, k8sPipelineVersion)
487493
require.Nil(t, err, "Failed to create Pipeline: %v", err)
488494
}
489495

496+
// getBasicPipelineSpec returns a basic PipelineSpec for testing purposes
497+
func getBasicPipelineSpec() v2beta1.IRSpec {
498+
return v2beta1.IRSpec{
499+
Value: map[string]interface{}{
500+
"pipelineInfo": map[string]interface{}{
501+
"name": "test-pipeline",
502+
"displayName": "Test Pipeline",
503+
},
504+
"root": map[string]interface{}{
505+
"dag": map[string]interface{}{
506+
"tasks": map[string]interface{}{},
507+
},
508+
},
509+
"schemaVersion": "2.1.0",
510+
"sdkVersion": "kfp-2.13.0",
511+
},
512+
}
513+
}
514+
515+
// getBasicPipelineSpecYAML returns a basic PipelineSpec as YAML string for model.PipelineVersion objects
516+
func getBasicPipelineSpecYAML() string {
517+
return `pipelineInfo:
518+
name: test-pipeline
519+
displayName: Test Pipeline
520+
root:
521+
dag:
522+
tasks: {}
523+
schemaVersion: "2.1.0"
524+
sdkVersion: kfp-2.13.0`
525+
}
526+
490527
func getClient() (client.Client, client.Client) {
491528
scheme := runtime.NewScheme()
492529
err := v2beta1.AddToScheme(scheme)
@@ -510,6 +547,9 @@ func getClient() (client.Client, client.Client) {
510547
Name: "Test Pipeline Version",
511548
Namespace: "Test",
512549
},
550+
Spec: v2beta1.PipelineVersionSpec{
551+
PipelineSpec: getBasicPipelineSpec(),
552+
},
513553
}
514554

515555
pipelineVersion1 := &v2beta1.PipelineVersion{
@@ -520,13 +560,19 @@ func getClient() (client.Client, client.Client) {
520560
"pipelines.kubeflow.org/pipeline-id": DefaultFakePipelineId,
521561
},
522562
},
563+
Spec: v2beta1.PipelineVersionSpec{
564+
PipelineSpec: getBasicPipelineSpec(),
565+
},
523566
}
524567

525568
pipelineVersion2 := &v2beta1.PipelineVersion{
526569
ObjectMeta: metav1.ObjectMeta{
527570
Name: "Test Pipeline Version 2",
528571
Namespace: "Test",
529572
},
573+
Spec: v2beta1.PipelineVersionSpec{
574+
PipelineSpec: getBasicPipelineSpec(),
575+
},
530576
}
531577

532578
pipelineVersion3 := &v2beta1.PipelineVersion{
@@ -544,7 +590,8 @@ func getClient() (client.Client, client.Client) {
544590
},
545591
},
546592
Spec: v2beta1.PipelineVersionSpec{
547-
Description: "Test Pipeline Version 1 Description",
593+
Description: "Test Pipeline Version 1 Description",
594+
PipelineSpec: getBasicPipelineSpec(),
548595
},
549596
}
550597

backend/src/apiserver/template/v2_template.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ func NewGenericScheduledWorkflow(modelJob *model.Job, ownerReferences []metav1.O
7979
}, nil
8080
}
8181

82+
func (t *V2Spec) PipelineSpec() *pipelinespec.PipelineSpec {
83+
return t.spec
84+
}
85+
86+
func (t *V2Spec) PlatformSpec() *pipelinespec.PlatformSpec {
87+
return t.platformSpec
88+
}
89+
8290
// Converts modelJob to ScheduledWorkflow.
8391
func (t *V2Spec) ScheduledWorkflow(modelJob *model.Job, ownerReferences []metav1.OwnerReference) (*scheduledworkflow.ScheduledWorkflow, error) {
8492
job := &pipelinespec.PipelineJob{}

backend/src/apiserver/webhook/pipelineversion_webhook.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ package webhook
1515

1616
import (
1717
"context"
18-
"encoding/json"
1918
"fmt"
2019
"net/http"
2120

@@ -91,18 +90,13 @@ func (p *PipelineVersionsWebhook) ValidateCreate(
9190
return nil, newBadRequestError(fmt.Sprintf("Expected a PipelineVersion object but got %T", pipelineVersion))
9291
}
9392

94-
_, err := p.getPipeline(ctx, pipelineVersion.Namespace, pipelineVersion.Spec.PipelineName)
93+
modelPipelineVersion, err := pipelineVersion.ToModel()
9594
if err != nil {
96-
return nil, err
97-
}
98-
99-
pipelineSpec, err := json.Marshal(pipelineVersion.Spec.PipelineSpec.Value)
100-
if err != nil {
101-
return nil, newBadRequestError(fmt.Sprintf("The pipeline spec is invalid JSON: %v", err))
95+
return nil, newBadRequestError(fmt.Sprintf("The pipeline spec is invalid: %v", err))
10296
}
10397

10498
// cache enabled or not doesn't matter in this context
105-
tmpl, err := template.NewV2SpecTemplate(pipelineSpec, false)
99+
tmpl, err := template.NewV2SpecTemplate([]byte(modelPipelineVersion.PipelineSpec), false)
106100
if err != nil {
107101
return nil, newBadRequestError(fmt.Sprintf("The pipeline spec is invalid: %v", err))
108102
}

backend/src/apiserver/webhook/pipelineversion_webhook_test.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestPipelineVersionWebhook_ValidateCreate(t *testing.T) {
7474
},
7575
Spec: k8sapi.PipelineVersionSpec{
7676
PipelineName: "test-pipeline",
77-
PipelineSpec: k8sapi.PipelineIRSpec{
77+
PipelineSpec: k8sapi.IRSpec{
7878
Value: json.RawMessage(validPipelineSpecJSON),
7979
},
8080
},
@@ -111,7 +111,7 @@ func TestPipelineVersionWebhook_ValidateCreate_InvalidPipelineSpec(t *testing.T)
111111
},
112112
Spec: k8sapi.PipelineVersionSpec{
113113
PipelineName: "test-pipeline",
114-
PipelineSpec: k8sapi.PipelineIRSpec{
114+
PipelineSpec: k8sapi.IRSpec{
115115
Value: json.RawMessage(invalidPipelineSpecJSON),
116116
},
117117
},
@@ -133,7 +133,7 @@ func TestPipelineVersionWebhook_ValidateUpdate(t *testing.T) {
133133
},
134134
Spec: k8sapi.PipelineVersionSpec{
135135
PipelineName: "test-pipeline",
136-
PipelineSpec: k8sapi.PipelineIRSpec{
136+
PipelineSpec: k8sapi.IRSpec{
137137
Value: json.RawMessage(validPipelineSpecJSON),
138138
},
139139
},
@@ -164,7 +164,7 @@ func TestPipelineVersionWebhook_ValidateUpdate(t *testing.T) {
164164
},
165165
Spec: k8sapi.PipelineVersionSpec{
166166
PipelineName: "test-pipeline",
167-
PipelineSpec: k8sapi.PipelineIRSpec{
167+
PipelineSpec: k8sapi.IRSpec{
168168
Value: json.RawMessage(updatedPipelineSpecJSON),
169169
},
170170
},
@@ -186,7 +186,7 @@ func TestPipelineVersionWebhook_ValidateUpdate_MetadataChangeAllowed(t *testing.
186186
},
187187
Spec: k8sapi.PipelineVersionSpec{
188188
PipelineName: "test-pipeline",
189-
PipelineSpec: k8sapi.PipelineIRSpec{
189+
PipelineSpec: k8sapi.IRSpec{
190190
Value: json.RawMessage(validPipelineSpecJSON),
191191
},
192192
},
@@ -201,7 +201,7 @@ func TestPipelineVersionWebhook_ValidateUpdate_MetadataChangeAllowed(t *testing.
201201
},
202202
Spec: k8sapi.PipelineVersionSpec{
203203
PipelineName: "test-pipeline",
204-
PipelineSpec: k8sapi.PipelineIRSpec{
204+
PipelineSpec: k8sapi.IRSpec{
205205
Value: json.RawMessage(validPipelineSpecJSON),
206206
},
207207
},
@@ -232,7 +232,7 @@ func TestPipelineVersionWebhook_MutatingUpdate_FixesOwnersRef(t *testing.T) {
232232
},
233233
Spec: k8sapi.PipelineVersionSpec{
234234
PipelineName: "test-pipeline",
235-
PipelineSpec: k8sapi.PipelineIRSpec{
235+
PipelineSpec: k8sapi.IRSpec{
236236
Value: json.RawMessage(validPipelineSpecJSON),
237237
},
238238
},
@@ -245,3 +245,51 @@ func TestPipelineVersionWebhook_MutatingUpdate_FixesOwnersRef(t *testing.T) {
245245
require.Equal(t, pipelineVersion.OwnerReferences[0].Name, "test-pipeline")
246246
require.True(t, *pipelineVersion.OwnerReferences[0].BlockOwnerDeletion)
247247
}
248+
249+
func TestPipelineVersionWebhook_ValidateCreate_WithPlatformSpec(t *testing.T) {
250+
pipelineWebhook, _ := setupPipelineWebhookTest(t)
251+
252+
validPipelineSpec := map[string]interface{}{
253+
"pipelineInfo": map[string]interface{}{
254+
"name": "test-pipeline-v1",
255+
"description": "A simple test pipeline",
256+
},
257+
"root": map[string]interface{}{
258+
"dag": map[string]interface{}{
259+
"tasks": map[string]interface{}{},
260+
},
261+
},
262+
"schemaVersion": "2.1.0",
263+
"sdkVersion": "kfp-2.11.0",
264+
}
265+
266+
validPlatformSpec := map[string]interface{}{
267+
"platforms": map[string]interface{}{
268+
"kubernetes": map[string]interface{}{
269+
"pipelineConfig": map[string]interface{}{
270+
"workspace": map[string]interface{}{
271+
"size": "10Gi",
272+
},
273+
},
274+
},
275+
},
276+
}
277+
278+
pipelineVersion := &k8sapi.PipelineVersion{
279+
ObjectMeta: metav1.ObjectMeta{
280+
Name: "test-pipeline-v1",
281+
Namespace: "default",
282+
},
283+
Spec: k8sapi.PipelineVersionSpec{
284+
PipelineName: "test-pipeline",
285+
PipelineSpec: k8sapi.IRSpec{
286+
Value: validPipelineSpec,
287+
},
288+
PlatformSpec: &k8sapi.IRSpec{
289+
Value: validPlatformSpec,
290+
},
291+
},
292+
}
293+
_, err := pipelineWebhook.ValidateCreate(context.TODO(), pipelineVersion)
294+
assert.NoError(t, err, "Expected no error for a valid PipelineVersion with platform spec")
295+
}

0 commit comments

Comments
 (0)