@@ -29,6 +29,7 @@ import (
2929 "github.com/compose-spec/compose-go/v2/types"
3030 "github.com/containerd/errdefs"
3131 "github.com/docker/cli/cli-plugins/manager"
32+ "github.com/docker/docker/api/types/versions"
3233 "github.com/spf13/cobra"
3334 "golang.org/x/sync/errgroup"
3435
@@ -71,6 +72,7 @@ func (s *composeService) ensureModels(ctx context.Context, project *types.Projec
7172
7273type modelAPI struct {
7374 path string
75+ version string // cached plugin version
7476 env []string
7577 prepare func (ctx context.Context , cmd * exec.Cmd ) error
7678 cleanup func ()
@@ -170,7 +172,7 @@ func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig,
170172 }
171173 args = append (args , config .Model )
172174 // Only append RuntimeFlags if docker model CLI version is >= v1.0.6
173- if len (config .RuntimeFlags ) != 0 && m .supportsRuntimeFlags (ctx ) {
175+ if len (config .RuntimeFlags ) != 0 && m .supportsRuntimeFlags () {
174176 args = append (args , "--" )
175177 args = append (args , config .RuntimeFlags ... )
176178 }
@@ -279,113 +281,23 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
279281 return availableModels , nil
280282}
281283
282- // getModelVersion retrieves the docker model CLI version
283- func (m * modelAPI ) getModelVersion (ctx context.Context ) (string , error ) {
284- cmd := exec .CommandContext (ctx , m .path , "version" )
285- err := m .prepare (ctx , cmd )
286- if err != nil {
287- return "" , err
288- }
289-
290- output , err := cmd .CombinedOutput ()
291- if err != nil {
292- return "" , fmt .Errorf ("error getting docker model version: %w" , err )
293- }
294-
295- // Parse output like: "Docker Model Runner version v1.0.4"
296- // We need to extract the version string (e.g., "v1.0.4")
297- lines := strings .Split (strings .TrimSpace (string (output )), "\n " )
298- for _ , line := range lines {
299- if strings .Contains (line , "version" ) {
300- parts := strings .Fields (line )
301- for i , part := range parts {
302- if part == "version" && i + 1 < len (parts ) {
303- return parts [i + 1 ], nil
304- }
305- }
306- }
307- }
308-
309- return "" , fmt .Errorf ("could not parse docker model version from output: %s" , string (output ))
310- }
311-
312284// supportsRuntimeFlags checks if the docker model version supports runtime flags
313285// Runtime flags are supported in version >= v1.0.6
314- func (m * modelAPI ) supportsRuntimeFlags (ctx context.Context ) bool {
315- versionStr , err := m .getModelVersion (ctx )
316- if err != nil {
317- // If we can't determine the version, don't append runtime flags to be safe
318- return false
319- }
320-
321- // Parse version strings
322- currentVersion , err := parseVersion (versionStr )
323- if err != nil {
324- return false
325- }
326-
327- minVersion , err := parseVersion ("1.0.6" )
328- if err != nil {
286+ func (m * modelAPI ) supportsRuntimeFlags () bool {
287+ // If version is not cached, don't append runtime flags to be safe
288+ if m .version == "" {
329289 return false
330290 }
331291
332- return ! currentVersion .LessThan (minVersion )
333- }
334-
335- // parseVersion parses a semantic version string
336- // Strips build metadata and prerelease suffixes (e.g., "1.0.6-dirty" or "1.0.6+build")
337- func parseVersion (versionStr string ) (* semVersion , error ) {
338- // Remove 'v' prefix if present
339- versionStr = strings .TrimPrefix (versionStr , "v" )
292+ // Strip 'v' prefix if present (e.g., "v1.0.6" -> "1.0.6")
293+ versionStr := strings .TrimPrefix (m .version , "v" )
340294
341295 // Strip build metadata or prerelease suffix after "-" or "+"
342- // Examples: "1.0.6-dirty" -> "1.0.6", "1.0.6+build" -> "1.0.6"
296+ // This is necessary because versions.LessThan treats "1.0.6-dirty" < "1.0.6" per semver rules
297+ // but we want to compare the base version numbers only
343298 if idx := strings .IndexAny (versionStr , "-+" ); idx != - 1 {
344299 versionStr = versionStr [:idx ]
345300 }
346301
347- parts := strings .Split (versionStr , "." )
348- if len (parts ) < 2 {
349- return nil , fmt .Errorf ("invalid version format: %s" , versionStr )
350- }
351-
352- var v semVersion
353- var err error
354-
355- v .major , err = strconv .Atoi (parts [0 ])
356- if err != nil {
357- return nil , fmt .Errorf ("invalid major version: %s" , parts [0 ])
358- }
359-
360- v .minor , err = strconv .Atoi (parts [1 ])
361- if err != nil {
362- return nil , fmt .Errorf ("invalid minor version: %s" , parts [1 ])
363- }
364-
365- if len (parts ) > 2 {
366- v .patch , err = strconv .Atoi (parts [2 ])
367- if err != nil {
368- return nil , fmt .Errorf ("invalid patch version: %s" , parts [2 ])
369- }
370- }
371-
372- return & v , nil
373- }
374-
375- // semVersion represents a semantic version
376- type semVersion struct {
377- major int
378- minor int
379- patch int
380- }
381-
382- // LessThan compares two semantic versions
383- func (v * semVersion ) LessThan (other * semVersion ) bool {
384- if v .major != other .major {
385- return v .major < other .major
386- }
387- if v .minor != other .minor {
388- return v .minor < other .minor
389- }
390- return v .patch < other .patch
302+ return ! versions .LessThan (versionStr , "1.0.6" )
391303}
0 commit comments