diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 54a00abc1..e476a8ee6 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -217,8 +217,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) { Body: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ Messages: []types.Message{ - {Role: "user", Content: "hello world"}, - {Role: "assistant", Content: "hi there"}, + {Role: "user", Content: types.Content{Raw: "hello world"}}, + {Role: "assistant", Content: types.Content{Raw: "hi there"}}, }, }, }, @@ -252,8 +252,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { Body: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ Messages: []types.Message{ - {Role: "system", Content: "You are a helpful assistant"}, - {Role: "user", Content: "Hello, how are you?"}, + {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}}, + {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}}, }, }, }, @@ -285,10 +285,10 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { Body: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ Messages: []types.Message{ - {Role: "system", Content: "You are a helpful assistant"}, - {Role: "user", Content: "Hello, how are you?"}, - {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, - {Role: "user", Content: "Can you explain how prefix caching works?"}, + {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}}, + {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}}, + {Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}}, + {Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}}, }, }, }, @@ -318,12 +318,12 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { Body: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ Messages: []types.Message{ - {Role: "system", Content: "You are a helpful assistant"}, - {Role: "user", Content: "Hello, how are you?"}, - {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"}, - {Role: "user", Content: "Can you explain how prefix caching works?"}, - {Role: "assistant", Content: "Prefix caching is a technique where..."}, - {Role: "user", Content: "That's very helpful, thank you!"}, + {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}}, + {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}}, + {Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}}, + {Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}}, + {Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}}, + {Role: "user", Content: types.Content{Raw: "That's very helpful, thank you!"}}, }, }, }, @@ -443,7 +443,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) { b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) { // Generate messages for this scenario messages := make([]types.Message, scenario.messageCount) - messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."} + messages[0] = types.Message{Role: "system", Content: types.Content{Raw: "You are a helpful assistant."}} for i := 1; i < scenario.messageCount; i++ { role := "user" @@ -451,7 +451,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) { role = "assistant" } content := randomPrompt(scenario.messageLength) - messages[i] = types.Message{Role: role, Content: content} + messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}} } pod := &types.PodMetrics{ diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 2685a22d0..d50b41e44 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -17,7 +17,10 @@ limitations under the License. package types import ( + "encoding/json" + "errors" "fmt" + "strings" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -97,16 +100,75 @@ func (r *ChatCompletionsRequest) String() string { messagesLen := 0 for _, msg := range r.Messages { - messagesLen += len(msg.Content) + messagesLen += len(msg.Content.PlainText()) } - return fmt.Sprintf("{MessagesLength: %d}", messagesLen) } // Message represents a single message in a chat-completions request. type Message struct { - Role string - Content string // TODO: support multi-modal content + // Role is the message Role, optional values are 'user', 'assistant', ... + Role string `json:"role,omitempty"` + // Content defines text of this message + Content Content `json:"content,omitempty"` +} + +type Content struct { + Raw string + Structured []ContentBlock +} + +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL ImageBlock `json:"image_url,omitempty"` +} + +type ImageBlock struct { + Url string `json:"url,omitempty"` +} + +// UnmarshalJSON allow use both format +func (mc *Content) UnmarshalJSON(data []byte) error { + // Raw format + var str string + if err := json.Unmarshal(data, &str); err == nil { + mc.Raw = str + return nil + } + + // Block format + var blocks []ContentBlock + if err := json.Unmarshal(data, &blocks); err == nil { + mc.Structured = blocks + return nil + } + + return errors.New("content format not supported") +} + +func (mc Content) MarshalJSON() ([]byte, error) { + if mc.Raw != "" { + return json.Marshal(mc.Raw) + } + if mc.Structured != nil { + return json.Marshal(mc.Structured) + } + return json.Marshal("") +} + +func (mc Content) PlainText() string { + if mc.Raw != "" { + return mc.Raw + } + var sb strings.Builder + for _, block := range mc.Structured { + if block.Type == "text" { + sb.WriteString(block.Text) + sb.WriteString(" ") + } + } + return sb.String() } type Pod interface { diff --git a/pkg/epp/util/request/body_test.go b/pkg/epp/util/request/body_test.go index 64ab6de11..afe47ceac 100644 --- a/pkg/epp/util/request/body_test.go +++ b/pkg/epp/util/request/body_test.go @@ -58,8 +58,58 @@ func TestExtractRequestData(t *testing.T) { want: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ Messages: []types.Message{ - {Role: "system", Content: "this is a system message"}, - {Role: "user", Content: "hello"}, + {Role: "system", Content: types.Content{Raw: "this is a system message"}}, + {Role: "user", Content: types.Content{Raw: "hello"}}, + }, + }, + }, + }, + { + name: "chat completions request body with multi-modal content", + body: map[string]any{ + "model": "test", + "messages": []any{ + map[string]any{ + "role": "system", + "content": []map[string]any{ + { + "type": "text", + "text": "Describe this image in one sentence.", + }, + }, + }, + map[string]any{ + "role": "user", + "content": []map[string]any{ + { + "type": "image_url", + "image_url": map[string]any{ + "url": "https://example.com/images/dui.jpg.", + }, + }, + }, + }, + }, + }, + want: &types.LLMRequestBody{ + ChatCompletions: &types.ChatCompletionsRequest{ + Messages: []types.Message{ + {Role: "system", Content: types.Content{ + Structured: []types.ContentBlock{ + { + Text: "Describe this image in one sentence.", + Type: "text", + }, + }, + }}, + {Role: "user", Content: types.Content{ + Structured: []types.ContentBlock{ + { + Type: "image_url", + ImageURL: types.ImageBlock{Url: "https://example.com/images/dui.jpg."}, + }, + }, + }}, }, }, }, @@ -81,7 +131,7 @@ func TestExtractRequestData(t *testing.T) { }, want: &types.LLMRequestBody{ ChatCompletions: &types.ChatCompletionsRequest{ - Messages: []types.Message{{Role: "user", Content: "hello"}}, + Messages: []types.Message{{Role: "user", Content: types.Content{Raw: "hello"}}}, Tools: []any{map[string]any{"type": "function"}}, Documents: []any{map[string]any{"content": "doc"}}, ChatTemplate: "custom template",