diff --git a/cmd/cli/pkg/standalone/containers.go b/cmd/cli/pkg/standalone/containers.go index aed796ea1..8e6589e38 100644 --- a/cmd/cli/pkg/standalone/containers.go +++ b/cmd/cli/pkg/standalone/containers.go @@ -219,14 +219,7 @@ func ensureContainerStarted(ctx context.Context, dockerClient client.ContainerAP // CreateControllerContainer creates and starts a controller container. func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, port uint16, host string, environment string, doNotTrack bool, gpu gpupkg.GPUSupport, modelStorageVolume string, printer StatusPrinter, engineKind types.ModelRunnerEngineKind) error { - // Determine the target image. - var imageName string - switch gpu { - case gpupkg.GPUSupportCUDA: - imageName = ControllerImage + ":" + controllerImageTagCUDA() - default: - imageName = ControllerImage + ":" + controllerImageTagCPU() - } + imageName := controllerImageName(gpu) // Set up the container configuration. portStr := strconv.Itoa(int(port)) diff --git a/cmd/cli/pkg/standalone/controller_image.go b/cmd/cli/pkg/standalone/controller_image.go new file mode 100644 index 000000000..2962e97b7 --- /dev/null +++ b/cmd/cli/pkg/standalone/controller_image.go @@ -0,0 +1,48 @@ +package standalone + +import ( + "os" + + gpupkg "github.com/docker/model-runner/cmd/cli/pkg/gpu" +) + +const ( + // ControllerImage is the image used for the controller container. + ControllerImage = "docker/model-runner" + // defaultControllerImageVersion is the image version used for the controller container + defaultControllerImageVersion = "latest" +) + +func controllerImageVersion() string { + if version, ok := os.LookupEnv("MODEL_RUNNER_CONTROLLER_VERSION"); ok && version != "" { + return version + } + return defaultControllerImageVersion +} + +func controllerImageVariant(detectedGPU gpupkg.GPUSupport) string { + if variant, ok := os.LookupEnv("MODEL_RUNNER_CONTROLLER_VARIANT"); ok { + if variant == "cpu" || variant == "generic" { + return "" + } + return variant + } + switch detectedGPU { + case gpupkg.GPUSupportCUDA: + return "cuda" + default: + return "" + } +} + +func fmtControllerImageName(repo, version, variant string) string { + tag := repo + ":" + version + if len(variant) > 0 { + tag += "-" + variant + } + return tag +} + +func controllerImageName(detectedGPU gpupkg.GPUSupport) string { + return fmtControllerImageName(ControllerImage, controllerImageVersion(), controllerImageVariant(detectedGPU)) +} diff --git a/cmd/cli/pkg/standalone/images.go b/cmd/cli/pkg/standalone/images.go index 4c9e8716e..48897b248 100644 --- a/cmd/cli/pkg/standalone/images.go +++ b/cmd/cli/pkg/standalone/images.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "os" "github.com/docker/docker/api/types/image" "github.com/docker/docker/client" @@ -13,41 +12,9 @@ import ( gpupkg "github.com/docker/model-runner/cmd/cli/pkg/gpu" ) -const ( - // ControllerImage is the image used for the controller container. - ControllerImage = "docker/model-runner" - // defaultControllerImageTagCPU is the image tag used for the controller container - // when running with the CPU backend. - defaultControllerImageTagCPU = "latest" - // defaultControllerImageTagCUDA is the image tag used for the controller container - // when running with the CUDA GPU backend. - defaultControllerImageTagCUDA = "latest-cuda" -) - -func controllerImageTagCPU() string { - if version, ok := os.LookupEnv("MODEL_RUNNER_CONTROLLER_VERSION"); ok && version != "" { - return version - } - return defaultControllerImageTagCPU -} - -func controllerImageTagCUDA() string { - if version, ok := os.LookupEnv("MODEL_RUNNER_CONTROLLER_VERSION"); ok && version != "" { - return version + "-cuda" - } - return defaultControllerImageTagCUDA -} - // EnsureControllerImage ensures that the controller container image is pulled. func EnsureControllerImage(ctx context.Context, dockerClient client.ImageAPIClient, gpu gpupkg.GPUSupport, printer StatusPrinter) error { - // Determine the target image. - var imageName string - switch gpu { - case gpupkg.GPUSupportCUDA: - imageName = ControllerImage + ":" + controllerImageTagCUDA() - default: - imageName = ControllerImage + ":" + controllerImageTagCPU() - } + imageName := controllerImageName(gpu) // Perform the pull. out, err := dockerClient.ImagePull(ctx, imageName, image.PullOptions{}) @@ -80,13 +47,13 @@ func EnsureControllerImage(ctx context.Context, dockerClient client.ImageAPIClie // PruneControllerImages removes any unused controller container images. func PruneControllerImages(ctx context.Context, dockerClient client.ImageAPIClient, printer StatusPrinter) error { // Remove the standard image, if present. - imageNameCPU := ControllerImage + ":" + controllerImageTagCPU() + imageNameCPU := fmtControllerImageName(ControllerImage, controllerImageVersion(), "") if _, err := dockerClient.ImageRemove(ctx, imageNameCPU, image.RemoveOptions{}); err == nil { printer.Println("Removed image", imageNameCPU) } // Remove the CUDA GPU image, if present. - imageNameCUDA := ControllerImage + ":" + controllerImageTagCUDA() + imageNameCUDA := fmtControllerImageName(ControllerImage, controllerImageVersion(), "cuda") if _, err := dockerClient.ImageRemove(ctx, imageNameCUDA, image.RemoveOptions{}); err == nil { printer.Println("Removed image", imageNameCUDA) }