Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions pkg/inference/models/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package models

import (
"bytes"
"encoding/json"
"fmt"

"github.com/docker/model-runner/pkg/distribution/types"
Expand Down Expand Up @@ -112,3 +114,35 @@ type Model struct {
// or ModelPack format (*modelpack.Model).
Config types.ModelConfig `json:"config"`
}

// UnmarshalJSON implements custom JSON unmarshaling for Model.
// This is necessary because Config is an interface type (types.ModelConfig),
// and Go's standard JSON decoder cannot unmarshal directly into an interface.
// We use json.RawMessage to defer parsing of the config field, allowing for
// future extension to support multiple ModelConfig implementations.
func (m *Model) UnmarshalJSON(data []byte) error {
type Alias Model
aux := struct {
*Alias
Config json.RawMessage `json:"config"`
}{
Alias: (*Alias)(m),
}

if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if len(aux.Config) == 0 || bytes.Equal(aux.Config, []byte("null")) {
m.Config = nil
return nil
}

var cfg types.Config
if err := json.Unmarshal(aux.Config, &cfg); err != nil {
return err
}
m.Config = &cfg

return nil
}
318 changes: 318 additions & 0 deletions pkg/inference/models/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
package models

import (
"encoding/json"
"testing"

"github.com/docker/model-runner/pkg/distribution/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestModelUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
jsonData string
expected Model
}{
{
name: "full model with all config fields",
jsonData: `{
"id": "sha256:abc123",
"tags": ["ai/smollm2:latest", "ai/smollm2:1.7B-instruct-q4_K_M"],
"created": 1704067200,
"config": {
"format": "gguf",
"quantization": "Q4_K_M",
"parameters": "1.7B",
"architecture": "llama",
"size": "1.7B",
"context_size": 8192
}
}`,
expected: Model{
ID: "sha256:abc123",
Tags: []string{"ai/smollm2:latest", "ai/smollm2:1.7B-instruct-q4_K_M"},
Created: 1704067200,
Config: &types.Config{
Format: "gguf",
Quantization: "Q4_K_M",
Parameters: "1.7B",
Architecture: "llama",
Size: "1.7B",
ContextSize: int32Ptr(8192),
},
},
},
{
name: "model with minimal config",
jsonData: `{
"id": "sha256:def456",
"created": 1704067200,
"config": {
"format": "safetensors"
}
}`,
expected: Model{
ID: "sha256:def456",
Tags: nil,
Created: 1704067200,
Config: &types.Config{
Format: "safetensors",
},
},
},
{
name: "model with empty config",
jsonData: `{
"id": "sha256:ghi789",
"created": 1704067200,
"config": {}
}`,
expected: Model{
ID: "sha256:ghi789",
Tags: nil,
Created: 1704067200,
Config: &types.Config{},
},
},
{
name: "model with gguf metadata",
jsonData: `{
"id": "sha256:jkl012",
"tags": ["ai/testmodel:latest"],
"created": 1704067200,
"config": {
"format": "gguf",
"architecture": "llama",
"gguf": {
"llama.context_length": "4096",
"llama.embedding_length": "2048"
}
}
}`,
expected: Model{
ID: "sha256:jkl012",
Tags: []string{"ai/testmodel:latest"},
Created: 1704067200,
Config: &types.Config{
Format: "gguf",
Architecture: "llama",
GGUF: map[string]string{
"llama.context_length": "4096",
"llama.embedding_length": "2048",
},
},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var model Model
err := json.Unmarshal([]byte(tc.jsonData), &model)
require.NoError(t, err)

assert.Equal(t, tc.expected.ID, model.ID)
assert.Equal(t, tc.expected.Tags, model.Tags)
assert.Equal(t, tc.expected.Created, model.Created)

// Verify config is properly unmarshaled
require.NotNil(t, model.Config)
expectedConfig := tc.expected.Config.(*types.Config)
actualConfig, ok := model.Config.(*types.Config)
require.True(t, ok, "Config should be *types.Config")

assert.Equal(t, expectedConfig.Format, actualConfig.Format)
assert.Equal(t, expectedConfig.Quantization, actualConfig.Quantization)
assert.Equal(t, expectedConfig.Parameters, actualConfig.Parameters)
assert.Equal(t, expectedConfig.Architecture, actualConfig.Architecture)
assert.Equal(t, expectedConfig.Size, actualConfig.Size)
assert.Equal(t, expectedConfig.GGUF, actualConfig.GGUF)

if expectedConfig.ContextSize != nil {
require.NotNil(t, actualConfig.ContextSize)
assert.Equal(t, *expectedConfig.ContextSize, *actualConfig.ContextSize)
} else {
assert.Nil(t, actualConfig.ContextSize)
}
})
}
}

func TestModelUnmarshalJSONArray(t *testing.T) {
// This test simulates what the CLI does when listing models
jsonData := `[
{
"id": "sha256:abc123",
"tags": ["ai/model1:latest"],
"created": 1704067200,
"config": {
"format": "gguf",
"quantization": "Q4_K_M",
"architecture": "llama"
}
},
{
"id": "sha256:def456",
"tags": ["ai/model2:latest"],
"created": 1704067300,
"config": {
"format": "safetensors",
"architecture": "mistral"
}
}
]`

var models []Model
err := json.Unmarshal([]byte(jsonData), &models)
require.NoError(t, err)

require.Len(t, models, 2)

// Verify first model
assert.Equal(t, "sha256:abc123", models[0].ID)
assert.Equal(t, []string{"ai/model1:latest"}, models[0].Tags)
config0, ok := models[0].Config.(*types.Config)
require.True(t, ok)
assert.Equal(t, types.FormatGGUF, config0.Format)
assert.Equal(t, "Q4_K_M", config0.Quantization)
assert.Equal(t, "llama", config0.Architecture)

// Verify second model
assert.Equal(t, "sha256:def456", models[1].ID)
assert.Equal(t, []string{"ai/model2:latest"}, models[1].Tags)
config1, ok := models[1].Config.(*types.Config)
require.True(t, ok)
assert.Equal(t, types.FormatSafetensors, config1.Format)
assert.Equal(t, "mistral", config1.Architecture)
}

func TestModelJSONRoundTrip(t *testing.T) {
// Test that marshaling and unmarshaling preserves data
original := Model{
ID: "sha256:roundtrip123",
Tags: []string{"ai/testmodel:v1", "ai/testmodel:latest"},
Created: 1704067200,
Config: &types.Config{
Format: "gguf",
Quantization: "Q8_0",
Parameters: "7B",
Architecture: "llama",
Size: "7B",
ContextSize: int32Ptr(4096),
GGUF: map[string]string{
"llama.context_length": "4096",
},
},
}

// Marshal to JSON
jsonData, err := json.Marshal(original)
require.NoError(t, err)

// Unmarshal back
var unmarshaled Model
err = json.Unmarshal(jsonData, &unmarshaled)
require.NoError(t, err)

// Verify all fields match
assert.Equal(t, original.ID, unmarshaled.ID)
assert.Equal(t, original.Tags, unmarshaled.Tags)
assert.Equal(t, original.Created, unmarshaled.Created)

originalConfig := original.Config.(*types.Config)
unmarshaledConfig, ok := unmarshaled.Config.(*types.Config)
require.True(t, ok)

assert.Equal(t, originalConfig.Format, unmarshaledConfig.Format)
assert.Equal(t, originalConfig.Quantization, unmarshaledConfig.Quantization)
assert.Equal(t, originalConfig.Parameters, unmarshaledConfig.Parameters)
assert.Equal(t, originalConfig.Architecture, unmarshaledConfig.Architecture)
assert.Equal(t, originalConfig.Size, unmarshaledConfig.Size)
assert.Equal(t, originalConfig.GGUF, unmarshaledConfig.GGUF)
require.NotNil(t, unmarshaledConfig.ContextSize)
assert.Equal(t, *originalConfig.ContextSize, *unmarshaledConfig.ContextSize)
}

func TestModelUnmarshalJSONNullAndMissingConfig(t *testing.T) {
tests := []struct {
name string
jsonData string
}{
{
name: "missing config field",
jsonData: `{
"id": "sha256:abc123",
"tags": ["ai/smollm2:latest"],
"created": 1704067200
}`,
},
{
name: "explicit null config field",
jsonData: `{
"id": "sha256:abc123",
"tags": ["ai/smollm2:latest"],
"created": 1704067200,
"config": null
}`,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var model Model
err := json.Unmarshal([]byte(tc.jsonData), &model)
require.NoError(t, err)

// config should be nil for both missing and null cases
assert.Nil(t, model.Config)

// other fields should still be populated correctly
assert.Equal(t, "sha256:abc123", model.ID)
assert.Equal(t, []string{"ai/smollm2:latest"}, model.Tags)
assert.Equal(t, int64(1704067200), model.Created)
})
}
}

func TestModelUnmarshalJSONInvalidData(t *testing.T) {
tests := []struct {
name string
jsonData string
}{
{
name: "invalid JSON",
jsonData: `{invalid}`,
},
{
name: "wrong type for id",
jsonData: `{"id": 123, "config": {}}`,
},
{
name: "wrong type for tags",
jsonData: `{"id": "test", "tags": "not-an-array", "config": {}}`,
},
{
name: "config is string instead of object",
jsonData: `{"id": "test", "config": "not-an-object"}`,
},
{
name: "config is array instead of object",
jsonData: `{"id": "test", "config": [1, 2, 3]}`,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var model Model
err := json.Unmarshal([]byte(tc.jsonData), &model)
assert.Error(t, err)
})
}
}

// Helper function to create int32 pointers
func int32Ptr(i int32) *int32 {
return &i
}