Skip to content

Commit 7f583a6

Browse files
committed
Update GET and DELETE methos to filter by registry
Signed-off-by: ppadti <ppadti@redhat.com>
1 parent 270f4f6 commit 7f583a6

File tree

7 files changed

+121
-22
lines changed

7 files changed

+121
-22
lines changed

clients/ui/api/openapi/mod-arch.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2556,6 +2556,7 @@ components:
25562556
type: string
25572557
region:
25582558
type: string
2559+
description: Optional. AWS region for S3 source.
25592560
endpoint:
25602561
type: string
25612562
uri:

clients/ui/bff/internal/api/model_transfer_job_handler.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
type ModelTransferJobListEnvelope Envelope[*models.ModelTransferJobList, None]
1616
type ModelTransferJobEnvelope Envelope[*models.ModelTransferJob, None]
1717

18-
func (app *App) GetAllModelTransferJobsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
18+
func (app *App) GetAllModelTransferJobsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
1919
ctx := r.Context()
2020

2121
namespace, ok := ctx.Value(constants.NamespaceHeaderParameterKey).(string)
@@ -30,7 +30,12 @@ func (app *App) GetAllModelTransferJobsHandler(w http.ResponseWriter, r *http.Re
3030
return
3131
}
3232

33-
transferJobs, err := app.repositories.ModelRegistry.GetAllModelTransferJobs(ctx, client, namespace)
33+
modelRegistryID := ps.ByName(ModelRegistryId)
34+
if modelRegistryID == "" {
35+
app.badRequestResponse(w, r, fmt.Errorf("model registry name is required"))
36+
return
37+
}
38+
transferJobs, err := app.repositories.ModelRegistry.GetAllModelTransferJobs(ctx, client, namespace, modelRegistryID)
3439
if err != nil {
3540
app.serverErrorResponse(w, r, err)
3641
return
@@ -186,7 +191,13 @@ func (app *App) DeleteModelTransferJobHandler(w http.ResponseWriter, r *http.Req
186191
return
187192
}
188193

189-
deletedJob, err := app.repositories.ModelRegistry.DeleteModelTransferJob(ctx, client, namespace, jobName)
194+
modelRegistryID := ps.ByName(ModelRegistryId)
195+
if modelRegistryID == "" {
196+
app.badRequestResponse(w, r, fmt.Errorf("model registry name is required"))
197+
return
198+
}
199+
200+
deletedJob, err := app.repositories.ModelRegistry.DeleteModelTransferJob(ctx, client, namespace, jobName, modelRegistryID)
190201
if err != nil {
191202
if errors.Is(err, repositories.ErrJobNotFound) {
192203
app.notFoundResponse(w, r)

clients/ui/bff/internal/api/model_transfer_job_handler_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,23 @@ var _ = Describe("TestModelTransferJob", func() {
10461046
Expect(rs.StatusCode).To(Equal(http.StatusBadRequest))
10471047
})
10481048

1049+
It("PATCH returns 404 when job exists but belongs to different registry", func() {
1050+
payload := ModelTransferJobEnvelope{
1051+
Data: &models.ModelTransferJob{
1052+
Name: "new-job-name",
1053+
},
1054+
}
1055+
_, rs, err := setupApiTest[Envelope[any, any]](
1056+
http.MethodPatch,
1057+
"/api/v1/model_registry/other-registry/model_transfer_jobs/transfer-job-001?namespace=kubeflow",
1058+
payload,
1059+
kubernetesMockedStaticClientFactory,
1060+
requestIdentity,
1061+
"kubeflow",
1062+
)
1063+
Expect(err).NotTo(HaveOccurred())
1064+
Expect(rs.StatusCode).To(Equal(http.StatusNotFound))
1065+
})
10491066
})
10501067

10511068
Context("deleting model transfer job", func() {
@@ -1074,5 +1091,45 @@ var _ = Describe("TestModelTransferJob", func() {
10741091
Expect(err).NotTo(HaveOccurred())
10751092
Expect(rs.StatusCode).To(Equal(http.StatusNotFound))
10761093
})
1094+
1095+
It("DELETE returns 404 when job exists but belongs to different registry", func() {
1096+
_, rs, err := setupApiTest[Envelope[any, any]](
1097+
http.MethodDelete,
1098+
"/api/v1/model_registry/other-registry/model_transfer_jobs/transfer-job-001?namespace=kubeflow",
1099+
nil,
1100+
kubernetesMockedStaticClientFactory,
1101+
requestIdentity,
1102+
"kubeflow",
1103+
)
1104+
Expect(err).NotTo(HaveOccurred())
1105+
Expect(rs.StatusCode).To(Equal(http.StatusNotFound))
1106+
})
1107+
})
1108+
})
1109+
1110+
var _ = Describe("TestModelTransferJob registry filtering", func() {
1111+
var requestIdentity kubernetes.RequestIdentity
1112+
1113+
BeforeEach(func() {
1114+
requestIdentity = kubernetes.RequestIdentity{
1115+
UserID: "user@example.com",
1116+
}
1117+
})
1118+
1119+
Context("GET list filtered by registry", func() {
1120+
It("GET list for other registry returns 200 with empty items", func() {
1121+
envelope, rs, err := setupApiTest[ModelTransferJobListEnvelope](
1122+
http.MethodGet,
1123+
"/api/v1/model_registry/other-registry/model_transfer_jobs?namespace=kubeflow",
1124+
nil,
1125+
kubernetesMockedStaticClientFactory,
1126+
requestIdentity,
1127+
"kubeflow",
1128+
)
1129+
Expect(err).NotTo(HaveOccurred())
1130+
Expect(rs.StatusCode).To(Equal(http.StatusOK))
1131+
Expect(envelope.Data).NotTo(BeNil())
1132+
Expect(envelope.Data.Items).To(BeEmpty())
1133+
})
10771134
})
10781135
})

clients/ui/bff/internal/integrations/kubernetes/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type KubernetesClientInterface interface {
4545
DeleteSecret(ctx context.Context, namespace string, secretName string) error
4646

4747
// Model transfer jobs
48-
GetAllModelTransferJobs(ctx context.Context, namespace string) (*batchv1.JobList, error)
48+
GetAllModelTransferJobs(ctx context.Context, namespace string, modelRegistryID string) (*batchv1.JobList, error)
4949
CreateModelTransferJob(ctx context.Context, namespace string, job *batchv1.Job) error
5050
DeleteModelTransferJob(ctx context.Context, namespace string, jobName string) error
5151
CreateConfigMap(ctx context.Context, namespace string, configMap *corev1.ConfigMap) error

clients/ui/bff/internal/integrations/kubernetes/k8mocks/base_testenv.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,9 @@ func createModelTransferJob(k8sClient kubernetes.Interface, ctx context.Context,
707707
Name: "transfer-job-001",
708708
Namespace: namespace,
709709
Labels: map[string]string{
710-
"modelregistry.kubeflow.org/job-type": "async-upload",
711-
"modelregistry.kubeflow.org/job-id": "001",
710+
"modelregistry.kubeflow.org/job-type": "async-upload",
711+
"modelregistry.kubeflow.org/job-id": "001",
712+
"modelregistry.kubeflow.org/model-registry-name": "model-registry",
712713
},
713714
Annotations: map[string]string{
714715
"modelregistry.kubeflow.org/registered-model-id": "1",

clients/ui/bff/internal/integrations/kubernetes/shared_k8s_client.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,19 @@ func (kc *SharedClientLogic) UpdateCatalogSourceConfig(
330330
func (kc *SharedClientLogic) GetAllModelTransferJobs(
331331
ctx context.Context,
332332
namespace string,
333+
modelRegistryID string,
333334
) (*batchv1.JobList, error) {
334335
if namespace == "" {
335336
return &batchv1.JobList{}, fmt.Errorf("namespace cannot be empty")
336337
}
337338

339+
if modelRegistryID == "" {
340+
return &batchv1.JobList{}, fmt.Errorf("model registry name is required")
341+
}
342+
338343
sessionLogger := ctx.Value(constants.TraceLoggerKey).(*slog.Logger)
339344

340-
labelSelector := "modelregistry.kubeflow.org/job-type=async-upload"
345+
labelSelector := "modelregistry.kubeflow.org/job-type=async-upload,modelregistry.kubeflow.org/model-registry-name=" + modelRegistryID
341346

342347
modelTransferJobList, err := kc.Client.BatchV1().Jobs(namespace).List(ctx, metav1.ListOptions{
343348
LabelSelector: labelSelector,

clients/ui/bff/internal/repositories/model_transfer_jobs.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ var (
2424
ErrJobValidationFailed = errors.New("validation failed")
2525
)
2626

27-
func (m *ModelRegistryRepository) GetAllModelTransferJobs(ctx context.Context, client k8s.KubernetesClientInterface, namespace string) (*models.ModelTransferJobList, error) {
28-
jobList, err := client.GetAllModelTransferJobs(ctx, namespace)
27+
func (m *ModelRegistryRepository) GetAllModelTransferJobs(ctx context.Context, client k8s.KubernetesClientInterface, namespace string, modelRegistryID string) (*models.ModelTransferJobList, error) {
28+
if modelRegistryID == "" {
29+
return &models.ModelTransferJobList{Items: nil, Size: 0, PageSize: 0}, nil
30+
}
31+
32+
jobList, err := client.GetAllModelTransferJobs(ctx, namespace, modelRegistryID)
2933
if err != nil {
3034
return nil, fmt.Errorf("failed to fetch model transfer jobs: %w", err)
3135
}
@@ -107,7 +111,7 @@ func (m *ModelRegistryRepository) createModelTransferJobResources(
107111
}
108112
}
109113

110-
job := buildK8sJob(jobName, namespace, jobID, payload, configMapName, sourceSecretName, destSecretName, modelRegistryAddress)
114+
job := buildK8sJob(jobName, namespace, jobID, payload, configMapName, sourceSecretName, destSecretName, modelRegistryAddress, modelRegistryID)
111115
if err := client.CreateModelTransferJob(ctx, namespace, job); err != nil {
112116
cleanupCreatedResources(ctx, client, namespace, configMapName, sourceSecretName, destSecretName)
113117
if apierrors.IsAlreadyExists(err) {
@@ -195,6 +199,11 @@ func (m *ModelRegistryRepository) UpdateModelTransferJob(
195199
oldAnnotations = map[string]string{}
196200
}
197201

202+
jobRegistry := oldJob.Labels["modelregistry.kubeflow.org/model-registry-name"]
203+
if jobRegistry != modelRegistryID {
204+
return nil, fmt.Errorf("%w: %s", ErrJobNotFound, oldJobName)
205+
}
206+
198207
oldConfigMapName := oldAnnotations["modelregistry.kubeflow.org/configmap-name"]
199208
if oldConfigMapName == "" {
200209
return nil, fmt.Errorf("old job missing required annotation: configmap-name (job may not have been created via this API)")
@@ -334,7 +343,10 @@ func (m *ModelRegistryRepository) UpdateModelTransferJob(
334343
return result, err
335344
}
336345

337-
func (m *ModelRegistryRepository) DeleteModelTransferJob(ctx context.Context, client k8s.KubernetesClientInterface, namespace string, jobName string) (*models.ModelTransferJob, error) {
346+
func (m *ModelRegistryRepository) DeleteModelTransferJob(ctx context.Context, client k8s.KubernetesClientInterface, namespace string, jobName string, modelRegistryID string) (*models.ModelTransferJob, error) {
347+
if modelRegistryID == "" {
348+
return nil, fmt.Errorf("%w: model registry name is required", ErrJobNotFound)
349+
}
338350
job, err := client.GetModelTransferJob(ctx, namespace, jobName)
339351
if err != nil {
340352
if apierrors.IsNotFound(err) {
@@ -343,6 +355,11 @@ func (m *ModelRegistryRepository) DeleteModelTransferJob(ctx context.Context, cl
343355
return nil, fmt.Errorf("failed to get job: %w", err)
344356
}
345357

358+
jobRegistry := job.Labels["modelregistry.kubeflow.org/model-registry-name"]
359+
if jobRegistry != modelRegistryID {
360+
return nil, fmt.Errorf("%w: %s", ErrJobNotFound, jobName)
361+
}
362+
346363
if err := client.DeleteModelTransferJob(ctx, namespace, jobName); err != nil {
347364
if apierrors.IsNotFound(err) {
348365
return nil, fmt.Errorf("%w: %s", ErrJobNotFound, jobName)
@@ -363,7 +380,7 @@ func (m *ModelRegistryRepository) getModelRegistryAddress(ctx context.Context, c
363380
}
364381

365382
func buildK8sJob(jobName, namespace, jobID string, payload models.ModelTransferJob,
366-
configMapName, sourceSecretName, destSecretName, modelRegistryAddress string) *batchv1.Job {
383+
configMapName, sourceSecretName, destSecretName, modelRegistryAddress, modelRegistryID string) *batchv1.Job {
367384

368385
backoffLimit := int32(3)
369386
baseImage := models.DefaultOCIBaseImage
@@ -449,21 +466,24 @@ func buildK8sJob(jobName, namespace, jobID string, payload models.ModelTransferJ
449466
annotations["modelregistry.kubeflow.org/source-bucket"] = payload.Source.Bucket
450467
annotations["modelregistry.kubeflow.org/source-key"] = payload.Source.Key
451468
annotations["modelregistry.kubeflow.org/source-secret"] = sourceSecretName
469+
annotations["modelregistry.kubeflow.org/model-registry-name"] = modelRegistryID
452470

453471
case models.ModelTransferJobSourceTypeURI:
454472
envVars = append(envVars,
455473
corev1.EnvVar{Name: "MODEL_SYNC_SOURCE_URI", Value: payload.Source.URI},
456474
)
457475
annotations["modelregistry.kubeflow.org/source-uri"] = payload.Source.URI
476+
annotations["modelregistry.kubeflow.org/model-registry-name"] = modelRegistryID
458477
}
459478

460479
return &batchv1.Job{
461480
ObjectMeta: metav1.ObjectMeta{
462481
Name: jobName,
463482
Namespace: namespace,
464483
Labels: map[string]string{
465-
"modelregistry.kubeflow.org/job-type": "async-upload",
466-
"modelregistry.kubeflow.org/job-id": jobID,
484+
"modelregistry.kubeflow.org/job-type": "async-upload",
485+
"modelregistry.kubeflow.org/job-id": jobID,
486+
"modelregistry.kubeflow.org/model-registry-name": modelRegistryID,
467487
},
468488
Annotations: annotations,
469489
},
@@ -606,6 +626,16 @@ func buildModelMetadataConfigMap(name, namespace string, payload models.ModelTra
606626
}
607627

608628
func buildSourceSecret(name, namespace string, payload models.ModelTransferJob, jobID string) *corev1.Secret {
629+
stringData := map[string]string{
630+
"AWS_ACCESS_KEY_ID": payload.Source.AwsAccessKeyId,
631+
"AWS_SECRET_ACCESS_KEY": payload.Source.AwsSecretAccessKey,
632+
"AWS_S3_ENDPOINT": payload.Source.Endpoint,
633+
"AWS_S3_BUCKET": payload.Source.Bucket,
634+
}
635+
636+
if payload.Source.Region != "" {
637+
stringData["AWS_DEFAULT_REGION"] = payload.Source.Region
638+
}
609639
return &corev1.Secret{
610640
ObjectMeta: metav1.ObjectMeta{
611641
Name: name,
@@ -615,14 +645,8 @@ func buildSourceSecret(name, namespace string, payload models.ModelTransferJob,
615645
"modelregistry.kubeflow.org/job-id": jobID,
616646
},
617647
},
618-
Type: corev1.SecretTypeOpaque,
619-
StringData: map[string]string{
620-
"AWS_ACCESS_KEY_ID": payload.Source.AwsAccessKeyId,
621-
"AWS_SECRET_ACCESS_KEY": payload.Source.AwsSecretAccessKey,
622-
"AWS_DEFAULT_REGION": payload.Source.Region,
623-
"AWS_S3_ENDPOINT": payload.Source.Endpoint,
624-
"AWS_S3_BUCKET": payload.Source.Bucket,
625-
},
648+
Type: corev1.SecretTypeOpaque,
649+
StringData: stringData,
626650
}
627651
}
628652

0 commit comments

Comments
 (0)