Skip to content

Commit 5840316

Browse files
ilopezlunandeloof
authored andcommitted
Only append RuntimeFlags if docker model CLI version is >= v1.0.6
Signed-off-by: Ignacio López Luna <[email protected]>
1 parent 6aee7f8 commit 5840316

File tree

1 file changed

+113
-1
lines changed

1 file changed

+113
-1
lines changed

pkg/compose/model.go

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig,
169169
args = append(args, "--context-size", strconv.Itoa(config.ContextSize))
170170
}
171171
args = append(args, config.Model)
172-
if len(config.RuntimeFlags) != 0 {
172+
// Only append RuntimeFlags if docker model CLI version is >= v1.0.6
173+
if len(config.RuntimeFlags) != 0 && m.supportsRuntimeFlags(ctx) {
173174
args = append(args, "--")
174175
args = append(args, config.RuntimeFlags...)
175176
}
@@ -277,3 +278,114 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
277278
}
278279
return availableModels, nil
279280
}
281+
282+
// getModelVersion retrieves the docker model CLI version
283+
func (m *modelAPI) getModelVersion(ctx context.Context) (string, error) {
284+
cmd := exec.CommandContext(ctx, m.path, "version")
285+
err := m.prepare(ctx, cmd)
286+
if err != nil {
287+
return "", err
288+
}
289+
290+
output, err := cmd.CombinedOutput()
291+
if err != nil {
292+
return "", fmt.Errorf("error getting docker model version: %w", err)
293+
}
294+
295+
// Parse output like: "Docker Model Runner version v1.0.4"
296+
// We need to extract the version string (e.g., "v1.0.4")
297+
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
298+
for _, line := range lines {
299+
if strings.Contains(line, "version") {
300+
parts := strings.Fields(line)
301+
for i, part := range parts {
302+
if part == "version" && i+1 < len(parts) {
303+
return parts[i+1], nil
304+
}
305+
}
306+
}
307+
}
308+
309+
return "", fmt.Errorf("could not parse docker model version from output: %s", string(output))
310+
}
311+
312+
// supportsRuntimeFlags checks if the docker model version supports runtime flags
313+
// Runtime flags are supported in version >= v1.0.6
314+
func (m *modelAPI) supportsRuntimeFlags(ctx context.Context) bool {
315+
versionStr, err := m.getModelVersion(ctx)
316+
if err != nil {
317+
// If we can't determine the version, don't append runtime flags to be safe
318+
return false
319+
}
320+
321+
// Parse version strings
322+
currentVersion, err := parseVersion(versionStr)
323+
if err != nil {
324+
return false
325+
}
326+
327+
minVersion, err := parseVersion("1.0.6")
328+
if err != nil {
329+
return false
330+
}
331+
332+
return !currentVersion.LessThan(minVersion)
333+
}
334+
335+
// parseVersion parses a semantic version string
336+
// Strips build metadata and prerelease suffixes (e.g., "1.0.6-dirty" or "1.0.6+build")
337+
func parseVersion(versionStr string) (*semVersion, error) {
338+
// Remove 'v' prefix if present
339+
versionStr = strings.TrimPrefix(versionStr, "v")
340+
341+
// Strip build metadata or prerelease suffix after "-" or "+"
342+
// Examples: "1.0.6-dirty" -> "1.0.6", "1.0.6+build" -> "1.0.6"
343+
if idx := strings.IndexAny(versionStr, "-+"); idx != -1 {
344+
versionStr = versionStr[:idx]
345+
}
346+
347+
parts := strings.Split(versionStr, ".")
348+
if len(parts) < 2 {
349+
return nil, fmt.Errorf("invalid version format: %s", versionStr)
350+
}
351+
352+
var v semVersion
353+
var err error
354+
355+
v.major, err = strconv.Atoi(parts[0])
356+
if err != nil {
357+
return nil, fmt.Errorf("invalid major version: %s", parts[0])
358+
}
359+
360+
v.minor, err = strconv.Atoi(parts[1])
361+
if err != nil {
362+
return nil, fmt.Errorf("invalid minor version: %s", parts[1])
363+
}
364+
365+
if len(parts) > 2 {
366+
v.patch, err = strconv.Atoi(parts[2])
367+
if err != nil {
368+
return nil, fmt.Errorf("invalid patch version: %s", parts[2])
369+
}
370+
}
371+
372+
return &v, nil
373+
}
374+
375+
// semVersion represents a semantic version
376+
type semVersion struct {
377+
major int
378+
minor int
379+
patch int
380+
}
381+
382+
// LessThan compares two semantic versions
383+
func (v *semVersion) LessThan(other *semVersion) bool {
384+
if v.major != other.major {
385+
return v.major < other.major
386+
}
387+
if v.minor != other.minor {
388+
return v.minor < other.minor
389+
}
390+
return v.patch < other.patch
391+
}

0 commit comments

Comments
 (0)