diff --git a/.gitignore b/.gitignore index 7bfe97f718e..869d0565017 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ bin/ coverage.out covdatafiles/ .DS_Store +vendor \ No newline at end of file diff --git a/pkg/bridge/convert.go b/pkg/bridge/convert.go index 8e9995b824f..d93d5b5dc21 100644 --- a/pkg/bridge/convert.go +++ b/pkg/bridge/convert.go @@ -55,6 +55,8 @@ func Convert(ctx context.Context, dockerCli command.Cli, project *types.Project, if err != nil { return err } + // Set default model_var and endpoint_var if missing + setDefaultModelVariablesIfMissing(project) // for user to rely on compose.yaml attribute names, not go struct ones, we marshall back into YAML raw, err := project.MarshalYAML(types.WithSecretContent) // Marshall to YAML @@ -222,3 +224,31 @@ func inspectWithPull(ctx context.Context, dockerCli command.Cli, imageName strin } return inspect, err } + +// setDefaultModelVariablesIfMissing sets default model_var and endpoint_var for services that use models +// but don't have these variables explicitly defined. +func setDefaultModelVariablesIfMissing(project *types.Project) { + for serviceName, service := range project.Services { + if len(service.Models) == 0 { + continue + } + for modelRef, modelConfig := range service.Models { + if modelConfig == nil { + modelConfig = &types.ServiceModelConfig{} + service.Models[modelRef] = modelConfig + } + + if modelConfig.ModelVariable == "" || modelConfig.EndpointVariable == "" { + defaultModelVar, defaultEndpointVar := utils.GetModelVariables(modelRef) + + if modelConfig.ModelVariable == "" { + modelConfig.ModelVariable = defaultModelVar + } + if modelConfig.EndpointVariable == "" { + modelConfig.EndpointVariable = defaultEndpointVar + } + } + } + project.Services[serviceName] = service + } +} diff --git a/pkg/compose/model.go b/pkg/compose/model.go index 7ae67321ca4..b42248d0800 100644 --- a/pkg/compose/model.go +++ b/pkg/compose/model.go @@ -24,12 +24,12 @@ import ( "os/exec" "slices" "strconv" - "strings" "github.com/compose-spec/compose-go/v2/types" "github.com/containerd/errdefs" "github.com/docker/cli/cli-plugins/manager" "github.com/docker/compose/v2/pkg/progress" + "github.com/docker/compose/v2/pkg/utils" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" ) @@ -200,19 +200,20 @@ func (m *modelAPI) SetModelVariables(ctx context.Context, project *types.Project for _, service := range project.Services { for ref, modelConfig := range service.Models { model := project.Models[ref] - varPrefix := strings.ReplaceAll(strings.ToUpper(ref), "-", "_") + defaultModelVar, defaultEndpointVar := utils.GetModelVariables(ref) + var variable string if modelConfig != nil && modelConfig.ModelVariable != "" { variable = modelConfig.ModelVariable } else { - variable = varPrefix + "_MODEL" + variable = defaultModelVar } service.Environment[variable] = &model.Model if modelConfig != nil && modelConfig.EndpointVariable != "" { variable = modelConfig.EndpointVariable } else { - variable = varPrefix + "_URL" + variable = defaultEndpointVar } service.Environment[variable] = &status.Endpoint } diff --git a/pkg/utils/modelvar.go b/pkg/utils/modelvar.go new file mode 100644 index 00000000000..32d17a131f1 --- /dev/null +++ b/pkg/utils/modelvar.go @@ -0,0 +1,27 @@ +/* + Copyright 2020 Docker Compose CLI authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +import "strings" + +// GetModelVariables generates default model and endpoint variable names from a model reference. +// It converts the model reference to uppercase and replaces hyphens with underscores. +// Returns modelVariable (e.g., "AI_RUNNER_MODEL") and endpointVariable (e.g., "AI_RUNNER_URL"). +func GetModelVariables(modelRef string) (modelVariable, endpointVariable string) { + prefix := strings.ReplaceAll(strings.ToUpper(modelRef), "-", "_") + return prefix + "_MODEL", prefix + "_URL" +} diff --git a/pkg/utils/modelvar_test.go b/pkg/utils/modelvar_test.go new file mode 100644 index 00000000000..890bdffd989 --- /dev/null +++ b/pkg/utils/modelvar_test.go @@ -0,0 +1,91 @@ +/* + Copyright 2020 Docker Compose CLI authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetModelVariables(t *testing.T) { + tests := []struct { + name string + modelRef string + expectedModelVar string + expectedEndpointVar string + }{ + { + name: "simple name with underscore", + modelRef: "ai_runner", + expectedModelVar: "AI_RUNNER_MODEL", + expectedEndpointVar: "AI_RUNNER_URL", + }, + { + name: "name with hyphens", + modelRef: "ai-runner", + expectedModelVar: "AI_RUNNER_MODEL", + expectedEndpointVar: "AI_RUNNER_URL", + }, + { + name: "complex name with multiple hyphens", + modelRef: "my-llm-engine", + expectedModelVar: "MY_LLM_ENGINE_MODEL", + expectedEndpointVar: "MY_LLM_ENGINE_URL", + }, + { + name: "single word", + modelRef: "model", + expectedModelVar: "MODEL_MODEL", + expectedEndpointVar: "MODEL_URL", + }, + { + name: "mixed case", + modelRef: "AiRunner", + expectedModelVar: "AIRUNNER_MODEL", + expectedEndpointVar: "AIRUNNER_URL", + }, + { + name: "mixed case with hyphens", + modelRef: "Ai-Runner", + expectedModelVar: "AI_RUNNER_MODEL", + expectedEndpointVar: "AI_RUNNER_URL", + }, + { + name: "already uppercase with underscores", + modelRef: "AI_RUNNER", + expectedModelVar: "AI_RUNNER_MODEL", + expectedEndpointVar: "AI_RUNNER_URL", + }, + { + name: "lowercase simple", + modelRef: "airunner", + expectedModelVar: "AIRUNNER_MODEL", + expectedEndpointVar: "AIRUNNER_URL", + }, + } + + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + modelVar, endpointVar := GetModelVariables(tt.modelRef) + assert.Equal(t, tt.expectedModelVar, modelVar, "modelVariable mismatch") + assert.Equal(t, tt.expectedEndpointVar, endpointVar, "endpointVariable mismatch") + }, + ) + } +}