Skip to content
52 changes: 52 additions & 0 deletions pkg/plugins/datalayer/models/datasource_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Package models
package models

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
)

func TestDatasource(t *testing.T) {
source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType,
"models-data-source", parseModels, ModelsResponseType)
extractor, err := NewModelExtractor()
assert.Nil(t, err, "failed to create extractor")

dsType := source.TypedName().Type
assert.Equal(t, ModelsDataSourceType, dsType)

err = source.AddExtractor(extractor)
assert.Nil(t, err, "failed to add extractor")

err = source.AddExtractor(extractor)
assert.NotNil(t, err, "expected to fail to add the same extractor twice")

extractors := source.Extractors()
assert.Len(t, extractors, 1)
assert.Equal(t, extractor.TypedName().String(), extractors[0])

err = datalayer.RegisterSource(source)
assert.Nil(t, err, "failed to register")

ctx := context.Background()
factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Millisecond)
pod := &fwkdl.EndpointMetadata{
NamespacedName: types.NamespacedName{
Name: "pod1",
Namespace: "default",
},
Address: "1.2.3.4:5678",
}
endpoint := factory.NewEndpoint(ctx, pod, nil)
assert.NotNil(t, endpoint, "failed to create endpoint")

err = source.Collect(ctx, endpoint)
assert.NotNil(t, err, "expected to fail to collect metrics")
}
96 changes: 96 additions & 0 deletions pkg/plugins/datalayer/models/extractor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package models

import (
"context"
"fmt"
"reflect"
"strings"

fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
)

const modelsAttributeKey = "/v1/models"

// ModelInfoCollection defines models' data returned from /v1/models API
type ModelInfoCollection []ModelInfo

// ModelInfo defines model's data returned from /v1/models API
type ModelInfo struct {
ID string `json:"id"`
Parent string `json:"parent,omitempty"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parent field is not part of OpenAI standardization.
it's specific to vllm and might not work with other model servers.
I also don't think it's used (or should be used) anywhere.
I recommend removing this field.

OpenAI standard here:
https://platform.openai.com/docs/api-reference/models/list

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments

  • If not present, the omitempty kicks in so I don't see the downside of having it.
  • For use cases that need the parent information for Base/LoRA relations, if it is not provided by model extraction then one must assume the base model name is provided elsewhere. There is currently no other source of truth...

I think it is fine to rely on vLLM specific for that.

  1. It can be treated as part of the "contract" (same as the case when other model servers are expected to provide the MSP metrics even if by a different name).
  2. configuration of data sources is per EPP so you can always not enable this for other model servers . This is valid usage as long as we use homogeneous model server in a pool (other code breaks as well when this is not the case...)

}

// String returns a string representation of the model info
func (m *ModelInfo) String() string {
return fmt.Sprintf("%+v", *m)
}

// Clone returns a full copy of the object
func (m ModelInfoCollection) Clone() fwkdl.Cloneable {
if m == nil {
return nil
}
clone := make([]ModelInfo, len(m))
copy(clone, m)
return (*ModelInfoCollection)(&clone)
}

func (m ModelInfoCollection) String() string {
if m == nil {
return ""
}
parts := make([]string, len(m))
for i, p := range m {
parts[i] = p.String()
}
return "[" + strings.Join(parts, ", ") + "]"
}

// ModelResponse is the response from /v1/models API
type ModelResponse struct {
Object string `json:"object"`
Data []ModelInfo `json:"data"`
}

// ModelsResponseType is the type of models response
var (
ModelsResponseType = reflect.TypeOf(ModelResponse{})
)

// ModelExtractor implements the models extraction.
type ModelExtractor struct {
typedName fwkplugin.TypedName
}

// NewModelExtractor returns a new model extractor.
func NewModelExtractor() (*ModelExtractor, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: at least in theory, the plugin could have a name...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ModelExtractor is a plugin. A plugin has a type and an optional name.
The code does not support setting a plugin name and it should.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is the WithName() method now

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks.
I was also thinking NewModelExtractor() should be extended with a name string parameter. If empty it is set to the type and WithName() is called internally.
I think that that would have been more consistent with other plugins.

return &ModelExtractor{
typedName: fwkplugin.TypedName{
Type: ModelsExtractorType,
Name: ModelsExtractorType,
},
}, nil
}

// TypedName returns the type and name of the ModelExtractor.
func (me *ModelExtractor) TypedName() fwkplugin.TypedName {
return me.typedName
}

// ExpectedInputType defines the type expected by ModelExtractor.
func (me *ModelExtractor) ExpectedInputType() reflect.Type {
return ModelsResponseType
}

// Extract transforms the data source output into a concrete attribute that
// is stored on the given endpoint.
func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error {
models, ok := data.(*ModelResponse)
if !ok {
return fmt.Errorf("unexpected input in Extract: %T", data)
}

ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data))
return nil
}
113 changes: 113 additions & 0 deletions pkg/plugins/datalayer/models/extractor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package models

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"

fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
)

func TestExtractorExtract(t *testing.T) {
ctx := context.Background()

extractor, err := NewModelExtractor()
if err != nil {
t.Fatalf("failed to create extractor: %v", err)
}

if exType := extractor.TypedName().Type; exType == "" {
t.Error("empty extractor type")
}

if exName := extractor.TypedName().Name; exName == "" {
t.Error("empty extractor name")
}

if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType {
t.Errorf("incorrect expected input type: %v", inputType)
}

ep := fwkdl.NewEndpoint(nil, nil)
if ep == nil {
t.Fatal("expected non-nil endpoint")
}

model := "food-review"

tests := []struct {
name string
data any
wantErr bool
updated bool // whether metrics are expected to change
}{
{
name: "nil data",
data: nil,
wantErr: true,
updated: false,
},
{
name: "empty ModelsResponse",
data: &ModelResponse{},
wantErr: false,
updated: false,
},
{
name: "valid models response",
data: &ModelResponse{
Object: "list",
Data: []ModelInfo{
{
ID: model,
},
{
ID: "lora1",
Parent: model,
},
},
},
wantErr: false,
updated: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("Extract panicked: %v", r)
}
}()

attr := ep.GetAttributes()
before, ok := attr.Get(modelsAttributeKey)
if ok && before != nil {
t.Error("expected empty attributes")
}
err := extractor.Extract(ctx, tt.data, ep)
after, ok := attr.Get(modelsAttributeKey)
if !ok && tt.updated {
t.Error("expected updated attributes")
}

if tt.wantErr && err == nil {
t.Errorf("expected error but got nil")
}
if !tt.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
}

if tt.updated {
if diff := cmp.Diff(before, after); diff == "" {
t.Errorf("expected models to be updated, but no change detected")
}
} else {
if diff := cmp.Diff(before, after); diff != "" {
t.Errorf("expected no models update, but got changes:\n%s", diff)
}
}
})
}
}
67 changes: 67 additions & 0 deletions pkg/plugins/datalayer/models/factories.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package models

import (
"encoding/json"
"fmt"
"io"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
)

const (
// ModelsDataSourceType is models data source type
ModelsDataSourceType = "models-data-source"
// ModelsExtractorType is models extractor type
ModelsExtractorType = "model-server-protocol-models"
)

// Configuration parameters for models data source.
type modelsDatasourceParams struct {
// Scheme defines the protocol scheme used in models retrieval (e.g., "http").
Scheme string `json:"scheme"`
// Path defines the URL path used in models retrieval (e.g., "/v1/models").
Path string `json:"path"`
// InsecureSkipVerify defines whether model server certificate should be verified or not.
InsecureSkipVerify bool `json:"insecureSkipVerify"`
}

// ModelDataSourceFactory is a factory function used to instantiate data layer's
// models data source plugins specified in a configuration.
func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
cfg := defaultDataSourceConfigParams()
if parameters != nil { // overlay the defaults with configured values
if err := json.Unmarshal(parameters, cfg); err != nil {
return nil, err
}
}

ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q; does NewHTTPDataSource validate the scheme?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is only a check if it's https

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we use the scheme passed in by the user it should at least sanitize it to ensure it's one one of a known set of acceptable values (e.g., "http" and "https").
Can be in this PR or separate adding scheme validation to the HTTPDataSource

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks.
Please open a tracking issue to move this check into HTTPDataSource in GAIE. It should not be up to each data source, IMO.

name, parseModels, ModelsResponseType)
return ds, nil
}

// ModelServerExtractorFactory is a factory function used to instantiate data layer's models
// Extractor plugins specified in a configuration.
func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
extractor, err := NewModelExtractor()
if err != nil {
return nil, err
}
extractor.typedName.Name = name
return extractor, nil
}

func defaultDataSourceConfigParams() *modelsDatasourceParams {
return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true}
}

func parseModels(data io.Reader) (any, error) {
body, err := io.ReadAll(data)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
var modelsResponse ModelResponse
err = json.Unmarshal(body, &modelsResponse)
return &modelsResponse, err
}
3 changes: 3 additions & 0 deletions pkg/plugins/register.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plugins

import (
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/datalayer/models"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
Expand All @@ -22,4 +23,6 @@ func RegisterAllPlugins() {
plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory)
plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory)
}