Skip to content

Commit f456686

Browse files
fixed handler
1 parent c190f00 commit f456686

File tree

4 files changed

+138
-15
lines changed

4 files changed

+138
-15
lines changed

bundle/direct/dresources/all_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/databricks/databricks-sdk-go/service/apps"
1717
"github.com/databricks/databricks-sdk-go/service/catalog"
1818
"github.com/databricks/databricks-sdk-go/service/database"
19+
"github.com/databricks/databricks-sdk-go/service/ml"
1920
"github.com/stretchr/testify/assert"
2021
"github.com/stretchr/testify/require"
2122
)
@@ -60,6 +61,18 @@ var testConfig map[string]any = map[string]any{
6061
Name: "main.myschema.my_synced_table",
6162
},
6263
},
64+
"models": &resources.MlflowModel{
65+
CreateModelRequest: ml.CreateModelRequest{
66+
Name: "my_mlflow_model",
67+
Description: "my_mlflow_model_description",
68+
Tags: []ml.ModelTag{
69+
{
70+
Key: "k1",
71+
Value: "v1",
72+
},
73+
},
74+
},
75+
},
6376
}
6477

6578
type prepareWorkspace func(client *databricks.WorkspaceClient) error

libs/testserver/fake_workspace.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/databricks/databricks-sdk-go/service/dashboards"
1717
"github.com/databricks/databricks-sdk-go/service/iam"
1818
"github.com/databricks/databricks-sdk-go/service/jobs"
19+
"github.com/databricks/databricks-sdk-go/service/ml"
1920
"github.com/databricks/databricks-sdk-go/service/pipelines"
2021
"github.com/databricks/databricks-sdk-go/service/sql"
2122
"github.com/databricks/databricks-sdk-go/service/workspace"
@@ -60,21 +61,22 @@ type FakeWorkspace struct {
6061
repoIdByPath map[string]int64
6162

6263
// normally, ids are not sequential, but we make them sequential for deterministic diff
63-
nextJobId int64
64-
nextJobRunId int64
65-
Jobs map[int64]jobs.Job
66-
JobRuns map[int64]jobs.Run
67-
JobPermissions map[string][]jobs.JobAccessControlRequest
68-
Pipelines map[string]pipelines.GetPipelineResponse
69-
PipelineUpdates map[string]bool
70-
Monitors map[string]catalog.MonitorInfo
71-
Apps map[string]apps.App
72-
Schemas map[string]catalog.SchemaInfo
73-
SchemasGrants map[string][]catalog.PrivilegeAssignment
74-
Volumes map[string]catalog.VolumeInfo
75-
Dashboards map[string]dashboards.Dashboard
76-
SqlWarehouses map[string]sql.GetWarehouseResponse
77-
Alerts map[string]sql.AlertV2
64+
nextJobId int64
65+
nextJobRunId int64
66+
Jobs map[int64]jobs.Job
67+
JobRuns map[int64]jobs.Run
68+
JobPermissions map[string][]jobs.JobAccessControlRequest
69+
Pipelines map[string]pipelines.GetPipelineResponse
70+
PipelineUpdates map[string]bool
71+
Monitors map[string]catalog.MonitorInfo
72+
Apps map[string]apps.App
73+
Schemas map[string]catalog.SchemaInfo
74+
SchemasGrants map[string][]catalog.PrivilegeAssignment
75+
Volumes map[string]catalog.VolumeInfo
76+
Dashboards map[string]dashboards.Dashboard
77+
SqlWarehouses map[string]sql.GetWarehouseResponse
78+
Alerts map[string]sql.AlertV2
79+
ModelRegistryModels map[string]ml.Model
7880

7981
Acls map[string][]workspace.AclItem
8082

@@ -172,6 +174,7 @@ func NewFakeWorkspace(url, token string) *FakeWorkspace {
172174
DatabaseCatalogs: map[string]database.DatabaseCatalog{},
173175
SyncedDatabaseTables: map[string]database.SyncedDatabaseTable{},
174176
Alerts: map[string]sql.AlertV2{},
177+
ModelRegistryModels: map[string]ml.Model{},
175178
}
176179
}
177180

libs/testserver/handlers.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,4 +521,21 @@ func AddDefaultHandlers(server *Server) {
521521
server.Handle("GET", "/api/2.0/permissions/jobs/{job_id}", func(req Request) any {
522522
return req.Workspace.JobsGetPermissions(req, req.Vars["job_id"])
523523
})
524+
525+
// Model registry models.
526+
server.Handle("POST", "/api/2.0/mlflow/registered-models/create", func(req Request) any {
527+
return req.Workspace.ModelRegistryCreateModel(req)
528+
})
529+
530+
server.Handle("GET", "/api/2.0/mlflow/databricks/registered-models/get", func(req Request) any {
531+
return req.Workspace.ModelRegistryGetModel(req)
532+
})
533+
534+
server.Handle("PATCH", "/api/2.0/mlflow/registered-models/update", func(req Request) any {
535+
return req.Workspace.ModelRegistryUpdateModel(req)
536+
})
537+
538+
server.Handle("DELETE", "/api/2.0/mlflow/registered-models/delete", func(req Request) any {
539+
return MapDelete(req.Workspace, req.Workspace.ModelRegistryModels, req.URL.Query().Get("name"))
540+
})
524541
}

libs/testserver/models.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package testserver
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
7+
"github.com/databricks/databricks-sdk-go/service/ml"
8+
)
9+
10+
func (s *FakeWorkspace) ModelRegistryCreateModel(req Request) any {
11+
defer s.LockUnlock()()
12+
13+
var request ml.CreateModelRequest
14+
if err := json.Unmarshal(req.Body, &request); err != nil {
15+
return Response{
16+
StatusCode: 400,
17+
Body: map[string]string{"message": fmt.Sprintf("Failed to parse request: %s", err)},
18+
}
19+
}
20+
21+
// Create the model
22+
model := ml.Model{
23+
Name: request.Name,
24+
Description: request.Description,
25+
Tags: request.Tags,
26+
}
27+
28+
s.ModelRegistryModels[request.Name] = model
29+
30+
return Response{
31+
Body: ml.CreateModelResponse{
32+
RegisteredModel: &model,
33+
},
34+
}
35+
}
36+
37+
func (s *FakeWorkspace) ModelRegistryUpdateModel(req Request) any {
38+
defer s.LockUnlock()()
39+
40+
var request ml.UpdateModelRequest
41+
if err := json.Unmarshal(req.Body, &request); err != nil {
42+
return Response{
43+
StatusCode: 400,
44+
Body: map[string]string{"message": fmt.Sprintf("Failed to parse request: %s", err)},
45+
}
46+
}
47+
48+
existingModel, ok := s.ModelRegistryModels[request.Name]
49+
if !ok {
50+
return Response{
51+
StatusCode: 404,
52+
Body: map[string]string{"message": fmt.Sprintf("Model not found: %v", request.Name)},
53+
}
54+
}
55+
56+
// Update the model
57+
existingModel.Description = request.Description
58+
s.ModelRegistryModels[request.Name] = existingModel
59+
60+
return Response{
61+
Body: ml.UpdateModelResponse{
62+
RegisteredModel: &existingModel,
63+
},
64+
}
65+
}
66+
67+
func (s *FakeWorkspace) ModelRegistryGetModel(req Request) any {
68+
defer s.LockUnlock()()
69+
70+
name := req.URL.Query().Get("name")
71+
72+
model, ok := s.ModelRegistryModels[name]
73+
if !ok {
74+
return Response{
75+
StatusCode: 404,
76+
Body: map[string]string{"message": fmt.Sprintf("Model not found: %v", name)},
77+
}
78+
}
79+
80+
return Response{
81+
Body: ml.GetModelResponse{
82+
RegisteredModelDatabricks: &ml.ModelDatabricks{
83+
Name: model.Name,
84+
Description: model.Description,
85+
Tags: model.Tags,
86+
ForceSendFields: model.ForceSendFields,
87+
},
88+
},
89+
}
90+
}

0 commit comments

Comments
 (0)