@@ -29,6 +29,7 @@ import (
2929 "github.com/compose-spec/compose-go/v2/types"
3030 "github.com/containerd/errdefs"
3131 "github.com/docker/cli/cli-plugins/manager"
32+ "github.com/docker/docker/api/types/versions"
3233 "github.com/spf13/cobra"
3334 "golang.org/x/sync/errgroup"
3435
@@ -71,6 +72,7 @@ func (s *composeService) ensureModels(ctx context.Context, project *types.Projec
7172
7273type modelAPI struct {
7374 path string
75+ version string // cached plugin version
7476 env []string
7577 prepare func (ctx context.Context , cmd * exec.Cmd ) error
7678 cleanup func ()
@@ -89,7 +91,8 @@ func (s *composeService) newModelAPI(project *types.Project) (*modelAPI, error)
8991 return nil , err
9092 }
9193 return & modelAPI {
92- path : dockerModel .Path ,
94+ path : dockerModel .Path ,
95+ version : dockerModel .Version ,
9396 prepare : func (ctx context.Context , cmd * exec.Cmd ) error {
9497 return s .prepareShellOut (ctx , project .Environment , cmd )
9598 },
@@ -165,7 +168,7 @@ func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig,
165168 }
166169 args = append (args , config .Model )
167170 // Only append RuntimeFlags if docker model CLI version is >= v1.0.6
168- if len (config .RuntimeFlags ) != 0 && m .supportsRuntimeFlags (ctx ) {
171+ if len (config .RuntimeFlags ) != 0 && m .supportsRuntimeFlags () {
169172 args = append (args , "--" )
170173 args = append (args , config .RuntimeFlags ... )
171174 }
@@ -274,113 +277,23 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
274277 return availableModels , nil
275278}
276279
277- // getModelVersion retrieves the docker model CLI version
278- func (m * modelAPI ) getModelVersion (ctx context.Context ) (string , error ) {
279- cmd := exec .CommandContext (ctx , m .path , "version" )
280- err := m .prepare (ctx , cmd )
281- if err != nil {
282- return "" , err
283- }
284-
285- output , err := cmd .CombinedOutput ()
286- if err != nil {
287- return "" , fmt .Errorf ("error getting docker model version: %w" , err )
288- }
289-
290- // Parse output like: "Docker Model Runner version v1.0.4"
291- // We need to extract the version string (e.g., "v1.0.4")
292- lines := strings .Split (strings .TrimSpace (string (output )), "\n " )
293- for _ , line := range lines {
294- if strings .Contains (line , "version" ) {
295- parts := strings .Fields (line )
296- for i , part := range parts {
297- if part == "version" && i + 1 < len (parts ) {
298- return parts [i + 1 ], nil
299- }
300- }
301- }
302- }
303-
304- return "" , fmt .Errorf ("could not parse docker model version from output: %s" , string (output ))
305- }
306-
307280// supportsRuntimeFlags checks if the docker model version supports runtime flags
308281// Runtime flags are supported in version >= v1.0.6
309- func (m * modelAPI ) supportsRuntimeFlags (ctx context.Context ) bool {
310- versionStr , err := m .getModelVersion (ctx )
311- if err != nil {
312- // If we can't determine the version, don't append runtime flags to be safe
313- return false
314- }
315-
316- // Parse version strings
317- currentVersion , err := parseVersion (versionStr )
318- if err != nil {
319- return false
320- }
321-
322- minVersion , err := parseVersion ("1.0.6" )
323- if err != nil {
282+ func (m * modelAPI ) supportsRuntimeFlags () bool {
283+ // If version is not cached, don't append runtime flags to be safe
284+ if m .version == "" {
324285 return false
325286 }
326287
327- return ! currentVersion .LessThan (minVersion )
328- }
329-
330- // parseVersion parses a semantic version string
331- // Strips build metadata and prerelease suffixes (e.g., "1.0.6-dirty" or "1.0.6+build")
332- func parseVersion (versionStr string ) (* semVersion , error ) {
333- // Remove 'v' prefix if present
334- versionStr = strings .TrimPrefix (versionStr , "v" )
288+ // Strip 'v' prefix if present (e.g., "v1.0.6" -> "1.0.6")
289+ versionStr := strings .TrimPrefix (m .version , "v" )
335290
336291 // Strip build metadata or prerelease suffix after "-" or "+"
337- // Examples: "1.0.6-dirty" -> "1.0.6", "1.0.6+build" -> "1.0.6"
292+ // This is necessary because versions.LessThan treats "1.0.6-dirty" < "1.0.6" per semver rules
293+ // but we want to compare the base version numbers only
338294 if idx := strings .IndexAny (versionStr , "-+" ); idx != - 1 {
339295 versionStr = versionStr [:idx ]
340296 }
341297
342- parts := strings .Split (versionStr , "." )
343- if len (parts ) < 2 {
344- return nil , fmt .Errorf ("invalid version format: %s" , versionStr )
345- }
346-
347- var v semVersion
348- var err error
349-
350- v .major , err = strconv .Atoi (parts [0 ])
351- if err != nil {
352- return nil , fmt .Errorf ("invalid major version: %s" , parts [0 ])
353- }
354-
355- v .minor , err = strconv .Atoi (parts [1 ])
356- if err != nil {
357- return nil , fmt .Errorf ("invalid minor version: %s" , parts [1 ])
358- }
359-
360- if len (parts ) > 2 {
361- v .patch , err = strconv .Atoi (parts [2 ])
362- if err != nil {
363- return nil , fmt .Errorf ("invalid patch version: %s" , parts [2 ])
364- }
365- }
366-
367- return & v , nil
368- }
369-
370- // semVersion represents a semantic version
371- type semVersion struct {
372- major int
373- minor int
374- patch int
375- }
376-
377- // LessThan compares two semantic versions
378- func (v * semVersion ) LessThan (other * semVersion ) bool {
379- if v .major != other .major {
380- return v .major < other .major
381- }
382- if v .minor != other .minor {
383- return v .minor < other .minor
384- }
385- return v .patch < other .patch
298+ return ! versions .LessThan (versionStr , "1.0.6" )
386299}
0 commit comments