Skip to content

Commit ab4fd40

Browse files
authored
[Fix] Fix model serving resource (#3690)
## Changes
<!-- Summary of your changes that are easy to understand -->
This PR fixes the provisioned throughput (PTP), external model (EM), and inference table journeys when they are managed via Terraform. It also removes validation checks that were invalid, along with planned checks that would have been invalid. No new checks are added: the API backend should remain the source of truth, and it returns an informative error when invalid parameters are provided. This should resolve #3676.

## Tests
<!-- How is this tested? Please see the checklist below and also describe any other relevant tests -->
We add new acceptance tests to cover the different endpoint types that can be created.

- [X] `make test` run locally
- [X] relevant change in `docs/` folder
- [X] covered with integration tests in `internal/acceptance`
- [x] relevant acceptance tests are passing
- [X] using Go SDK
1 parent 165cda6 commit ab4fd40

File tree

3 files changed

+138
-125
lines changed

3 files changed

+138
-125
lines changed

internal/acceptance/model_serving_test.go

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestUcAccModelServingProvisionedThroughput(t *testing.T) {
9797
config {
9898
served_entities{
9999
name = "pt_model"
100-
entity_name = "system.ai.mistral_7b_instruct_v0_1"
100+
entity_name = "system.ai.mistral_7b_instruct_v0_2"
101101
entity_version = "1"
102102
min_provisioned_throughput = 0
103103
max_provisioned_throughput = 970
@@ -111,6 +111,133 @@ func TestUcAccModelServingProvisionedThroughput(t *testing.T) {
111111
}
112112
}
113113
`, name),
114+
}, step{
115+
Template: fmt.Sprintf(`
116+
resource "databricks_model_serving" "endpoint" {
117+
name = "%s"
118+
config {
119+
served_entities{
120+
name = "pt_model"
121+
entity_name = "system.ai.mistral_7b_instruct_v0_2"
122+
entity_version = "1"
123+
min_provisioned_throughput = 970
124+
max_provisioned_throughput = 1940
125+
}
126+
traffic_config {
127+
routes {
128+
served_model_name = "pt_model"
129+
traffic_percentage = 100
130+
}
131+
}
132+
}
133+
}
134+
`, name),
135+
}, step{
136+
Template: fmt.Sprintf(`
137+
resource "databricks_model_serving" "endpoint" {
138+
name = "%s"
139+
config {
140+
served_entities{
141+
name = "pt_model"
142+
entity_name = "system.ai.mistral_7b_instruct_v0_2"
143+
entity_version = "1"
144+
min_provisioned_throughput = 0
145+
max_provisioned_throughput = 1940
146+
}
147+
traffic_config {
148+
routes {
149+
served_model_name = "pt_model"
150+
traffic_percentage = 100
151+
}
152+
}
153+
}
154+
}
155+
`, name),
114156
},
115157
)
116158
}
159+
160+
func TestAccModelServingExternalModel(t *testing.T) {
161+
loadWorkspaceEnv(t)
162+
if isGcp(t) {
163+
skipf(t)("not available on GCP")
164+
}
165+
166+
name := fmt.Sprintf("terraform-test-model-serving-em-%s",
167+
acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum))
168+
scope_name := fmt.Sprintf("terraform-test-secret-scope-%s",
169+
acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum))
170+
workspaceLevel(t, step{
171+
Template: fmt.Sprintf(`
172+
resource "databricks_secret_scope" "scope" {
173+
name = "%s"
174+
}
175+
176+
resource "databricks_secret" "key" {
177+
key = "api_key"
178+
string_value = "fake-secret"
179+
scope = databricks_secret_scope.scope.id
180+
}
181+
182+
resource "databricks_model_serving" "endpoint" {
183+
name = "%s"
184+
config {
185+
served_entities {
186+
name = "prod_model"
187+
external_model {
188+
provider = "anthropic"
189+
name = "claude-2.0"
190+
task = "llm/v1/chat"
191+
anthropic_config {
192+
anthropic_api_key = databricks_secret.key.config_reference
193+
}
194+
}
195+
}
196+
traffic_config {
197+
routes {
198+
served_model_name = "prod_model"
199+
traffic_percentage = 100
200+
}
201+
}
202+
}
203+
}
204+
`, scope_name, name),
205+
},
206+
step{
207+
Template: fmt.Sprintf(`
208+
resource "databricks_secret_scope" "scope" {
209+
name = "%s"
210+
}
211+
212+
resource "databricks_secret" "key" {
213+
key = "api_key"
214+
string_value = "fake-secret"
215+
scope = databricks_secret_scope.scope.id
216+
}
217+
218+
resource "databricks_model_serving" "endpoint" {
219+
name = "%s"
220+
config {
221+
served_entities {
222+
name = "prod_model"
223+
external_model {
224+
provider = "openai"
225+
name = "gpt-4o"
226+
task = "llm/v1/chat"
227+
openai_config {
228+
openai_api_key = databricks_secret.key.config_reference
229+
}
230+
}
231+
}
232+
traffic_config {
233+
routes {
234+
served_model_name = "prod_model"
235+
traffic_percentage = 100
236+
}
237+
}
238+
}
239+
}
240+
`, scope_name, name),
241+
},
242+
)
243+
}

serving/resource_model_serving.go

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@ package serving
22

33
import (
44
"context"
5-
"fmt"
65
"log"
7-
"slices"
8-
"strings"
96
"time"
107

118
"github.com/databricks/databricks-sdk-go/retries"
@@ -24,33 +21,24 @@ func ResourceModelServing() common.Resource {
2421
m["name"].ForceNew = true
2522
common.MustSchemaPath(m, "config", "served_models").ConflictsWith = []string{"config.served_entities"}
2623
common.MustSchemaPath(m, "config", "served_entities").ConflictsWith = []string{"config.served_models"}
24+
25+
common.MustSchemaPath(m, "config", "traffic_config").Computed = true
26+
common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").Computed = true
27+
common.MustSchemaPath(m, "config", "auto_capture_config", "enabled").Computed = true
28+
common.MustSchemaPath(m, "config", "auto_capture_config", "catalog_name").ForceNew = true
29+
common.MustSchemaPath(m, "config", "auto_capture_config", "schema_name").ForceNew = true
30+
common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").ForceNew = true
31+
32+
common.MustSchemaPath(m, "config", "served_models", "name").Computed = true
33+
common.MustSchemaPath(m, "config", "served_models", "workload_type").Computed = true
2734
common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Required = false
2835
common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Optional = true
2936
common.MustSchemaPath(m, "config", "served_models", "scale_to_zero_enabled").Default = true
30-
common.MustSchemaPath(m, "config", "served_models", "name").Computed = true
31-
common.MustSchemaPath(m, "config", "served_models", "workload_type").Default = "CPU"
32-
// TODO: `config.served_models.workload_type` should be a `Optional+Computed` field. Also consider this for other similar fields.
33-
// In this scenario, if a workspace does not have GPU serving, specifying `workload_type` = 'CPU' will get empty response from API.
34-
common.MustSchemaPath(m, "config", "served_models", "workload_type").DiffSuppressFunc = func(k, old, new string, d *schema.ResourceData) bool {
35-
return old == "" && new == "CPU"
36-
}
37-
common.MustSchemaPath(m, "config", "traffic_config").Computed = true
3837
common.MustSchemaPath(m, "config", "served_models").Deprecated = "Please use 'config.served_entities' instead of 'config.served_models'."
3938

40-
common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Required = false
41-
common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Optional = true
42-
common.MustSchemaPath(m, "config", "served_entities", "scale_to_zero_enabled").Default = false
4339
common.MustSchemaPath(m, "config", "served_entities", "name").Computed = true
44-
common.MustSchemaPath(m, "config", "served_entities", "workload_size").Optional = true
4540
common.MustSchemaPath(m, "config", "served_entities", "workload_size").Computed = true
46-
common.MustSchemaPath(m, "config", "served_entities", "workload_type").Optional = true
4741
common.MustSchemaPath(m, "config", "served_entities", "workload_type").Computed = true
48-
common.MustSchemaPath(m, "config", "served_entities", "workload_type").DiffSuppressFunc = func(k, old, new string, d *schema.ResourceData) bool {
49-
return old == "" && new == "CPU"
50-
}
51-
common.MustSchemaPath(m, "config", "auto_capture_config", "catalog_name").ForceNew = true
52-
common.MustSchemaPath(m, "config", "auto_capture_config", "schema_name").ForceNew = true
53-
common.MustSchemaPath(m, "config", "auto_capture_config", "table_name_prefix").ForceNew = true
5442

5543
m["serving_endpoint_id"] = &schema.Schema{
5644
Computed: true,
@@ -60,27 +48,13 @@ func ResourceModelServing() common.Resource {
6048
})
6149

6250
return common.Resource{
63-
CustomizeDiff: func(ctx context.Context, d *schema.ResourceDiff) error {
64-
old, new := d.GetChange("config.0.auto_capture_config.0.enabled")
65-
if old != nil && old == false && new == true {
66-
d.ForceNew("config.0.auto_capture_config.0.enabled")
67-
}
68-
err := validateExternalModelConfig(d)
69-
if err != nil {
70-
return err
71-
}
72-
return nil
73-
},
7451
Create: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error {
7552
w, err := c.WorkspaceClient()
7653
if err != nil {
7754
return err
7855
}
7956
var e serving.CreateServingEndpoint
8057
common.DataToStructPointer(d, s, &e)
81-
for i := range e.Config.ServedEntities {
82-
e.Config.ServedEntities[i].ForceSendFields = append(e.Config.ServedEntities[i].ForceSendFields, "ScaleToZeroEnabled", "MinProvisionedThroughput")
83-
}
8458
wait, err := w.ServingEndpoints.Create(ctx, e)
8559
if err != nil {
8660
return err
@@ -133,9 +107,6 @@ func ResourceModelServing() common.Resource {
133107
}
134108
var e serving.CreateServingEndpoint
135109
common.DataToStructPointer(d, s, &e)
136-
for i := range e.Config.ServedEntities {
137-
e.Config.ServedEntities[i].ForceSendFields = append(e.Config.ServedEntities[i].ForceSendFields, "ScaleToZeroEnabled")
138-
}
139110
e.Config.Name = e.Name
140111
_, err = w.ServingEndpoints.UpdateConfigAndWait(ctx, e.Config, retries.Timeout[serving.ServingEndpointDetailed](d.Timeout(schema.TimeoutUpdate)))
141112
return err
@@ -156,34 +127,3 @@ func ResourceModelServing() common.Resource {
156127
},
157128
}
158129
}
159-
160-
func validateExternalModelConfig(d *schema.ResourceDiff) error {
161-
_, e := d.GetOk("config.0.served_entities.0.external_model")
162-
provider, p := d.GetOk("config.0.served_entities.0.external_model.0.provider")
163-
164-
if !e || !p {
165-
return nil
166-
}
167-
168-
name := strings.ReplaceAll(provider.(string), "-", "_")
169-
config := d.Get(fmt.Sprintf("config.0.served_entities.0.external_model.0.%s_config", name)).([]interface{})
170-
171-
if len(config) == 0 {
172-
return fmt.Errorf("external_model provider is set to \"%s\" but \"%s_config\" block is missing", name, name)
173-
}
174-
175-
if configBlock, ok := d.Get("config.0.served_entities.0.external_model.0").(map[string]interface{}); ok {
176-
var found []string
177-
for key, value := range configBlock {
178-
if strings.HasSuffix(key, "_config") && len(value.([]interface{})) > 0 {
179-
found = append(found, key)
180-
}
181-
}
182-
slices.Sort(found)
183-
if len(found) > 1 {
184-
msg := strings.Join(found, ", ")
185-
return fmt.Errorf("only one external_model config block is allowed. Found: %s", msg)
186-
}
187-
}
188-
return nil
189-
}

serving/resource_model_serving_test.go

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -632,57 +632,3 @@ func TestModelServingDelete_Error(t *testing.T) {
632632
ID: "test-endpoint",
633633
}.ExpectError(t, "Internal error happened")
634634
}
635-
636-
func TestModelServingExternalModelNoConfig(t *testing.T) {
637-
qa.ResourceFixture{
638-
Resource: ResourceModelServing(),
639-
HCL: `
640-
name = "test-endpoint"
641-
config {
642-
served_entities {
643-
name = "prod_model"
644-
entity_name = "ads1"
645-
entity_version = "2"
646-
external_model {
647-
name = "prod_external_model"
648-
provider = "ai21labs"
649-
task = "llm/v1/embeddings"
650-
}
651-
workload_size = "Small"
652-
scale_to_zero_enabled = true
653-
}
654-
}
655-
`,
656-
Create: true,
657-
}.ExpectError(t, "external_model provider is set to \"ai21labs\" but \"ai21labs_config\" block is missing")
658-
}
659-
660-
func TestModelServingExternalModelMultipleConfig(t *testing.T) {
661-
qa.ResourceFixture{
662-
Resource: ResourceModelServing(),
663-
HCL: `
664-
name = "test-endpoint"
665-
config {
666-
served_entities {
667-
name = "prod_model"
668-
entity_name = "ads1"
669-
entity_version = "2"
670-
external_model {
671-
name = "prod_external_model"
672-
provider = "ai21labs"
673-
task = "llm/v1/embeddings"
674-
ai21labs_config {
675-
ai21labs_api_key = "{{secrets/databricks/ai21labs_api_key}}"
676-
}
677-
anthropic_config {
678-
anthropic_api_key = "{{secrets/databricks/anthropic_api_key}}"
679-
}
680-
}
681-
workload_size = "Small"
682-
scale_to_zero_enabled = true
683-
}
684-
}
685-
`,
686-
Create: true,
687-
}.ExpectError(t, "only one external_model config block is allowed. Found: ai21labs_config, anthropic_config")
688-
}

0 commit comments

Comments
 (0)