-
Notifications
You must be signed in to change notification settings - Fork 132
Models extractor #553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Models extractor #553
Changes from 9 commits
4bb0dfc
8b6e87e
4f37fa5
8e970f1
e6e9852
4628de9
55996a2
f3f582a
c0aac23
8eaf3ce
7d16b18
9c510f8
5454388
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| // Package models implements the data layer's models data source and extractor plugins. | ||
| package models | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "k8s.io/apimachinery/pkg/types" | ||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" | ||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http" | ||
| fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" | ||
| ) | ||
|
|
||
| func TestDatasource(t *testing.T) { | ||
| source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType, | ||
| "models-data-source", parseModels, ModelsResponseType) | ||
| extractor, err := NewModelExtractor() | ||
| assert.Nil(t, err, "failed to create extractor") | ||
|
|
||
| dsType := source.TypedName().Type | ||
| assert.Equal(t, ModelsDataSourceType, dsType) | ||
|
|
||
| err = source.AddExtractor(extractor) | ||
| assert.Nil(t, err, "failed to add extractor") | ||
|
|
||
| err = source.AddExtractor(extractor) | ||
| assert.NotNil(t, err, "expected to fail to add the same extractor twice") | ||
|
|
||
| extractors := source.Extractors() | ||
| assert.Len(t, extractors, 1) | ||
| assert.Equal(t, extractor.TypedName().String(), extractors[0]) | ||
|
|
||
| err = datalayer.RegisterSource(source) | ||
| assert.Nil(t, err, "failed to register") | ||
|
|
||
| ctx := context.Background() | ||
| factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Millisecond) | ||
elevran marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| pod := &fwkdl.EndpointMetadata{ | ||
| NamespacedName: types.NamespacedName{ | ||
| Name: "pod1", | ||
| Namespace: "default", | ||
| }, | ||
| Address: "1.2.3.4:5678", | ||
| } | ||
| endpoint := factory.NewEndpoint(ctx, pod, nil) | ||
| assert.NotNil(t, endpoint, "failed to create endpoint") | ||
|
|
||
| err = source.Collect(ctx, endpoint) | ||
| assert.NotNil(t, err, "expected to fail to collect metrics") | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| package models | ||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "reflect" | ||
| "strings" | ||
|
|
||
| fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" | ||
| fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" | ||
| ) | ||
|
|
||
// modelsAttributeKey is the attribute key under which the extracted model
// list is stored on an endpoint.
const modelsAttributeKey = "/v1/models"

// ModelInfoCollection defines models' data returned from /v1/models API.
type ModelInfoCollection []ModelInfo

// ModelInfo defines a single model's data returned from /v1/models API.
type ModelInfo struct {
	// ID is decoded from the "id" JSON field.
	ID string `json:"id"`
	// Parent is decoded from the optional "parent" JSON field.
	// Per review discussion, this field is not part of the OpenAI models
	// schema; relying on the vLLM-specific field was deemed acceptable.
	Parent string `json:"parent,omitempty"`
}

// String returns a string representation of the model info in fmt's "%+v"
// struct form. A nil receiver yields "<nil>" instead of panicking, matching
// what fmt prints for a nil Stringer.
func (m *ModelInfo) String() string {
	if m == nil {
		return "<nil>"
	}
	// Format the dereferenced value: the value type has no String method
	// (pointer receiver), so fmt does not recurse back into this method.
	return fmt.Sprintf("%+v", *m)
}
|
|
||
| // Clone returns a full copy of the object | ||
| func (m ModelInfoCollection) Clone() fwkdl.Cloneable { | ||
| if m == nil { | ||
| return nil | ||
| } | ||
| clone := make([]ModelInfo, len(m)) | ||
| copy(clone, m) | ||
| return (*ModelInfoCollection)(&clone) | ||
| } | ||
|
|
||
| func (m ModelInfoCollection) String() string { | ||
| if m == nil { | ||
| return "" | ||
elevran marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| parts := make([]string, len(m)) | ||
| for i, p := range m { | ||
| parts[i] = p.String() | ||
| } | ||
| return "[" + strings.Join(parts, ", ") + "]" | ||
| } | ||
|
|
||
| // ModelResponse is the response from /v1/models API | ||
| type ModelResponse struct { | ||
| Object string `json:"object"` | ||
| Data []ModelInfo `json:"data"` | ||
| } | ||
|
|
||
| // ModelsResponseType is the type of models response | ||
| var ( | ||
| ModelsResponseType = reflect.TypeOf(ModelResponse{}) | ||
| ) | ||
|
|
||
| // ModelExtractor implements the models extraction. | ||
| type ModelExtractor struct { | ||
| typedName fwkplugin.TypedName | ||
| } | ||
|
|
||
| // NewModelExtractor returns a new model extractor. | ||
| func NewModelExtractor() (*ModelExtractor, error) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: at least in theory, the plugin could have a name...
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ModelExtractor is a plugin. A plugin has a type and an optional name.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is the WithName() method now
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks. |
||
| return &ModelExtractor{ | ||
| typedName: fwkplugin.TypedName{ | ||
| Type: ModelsExtractorType, | ||
| Name: ModelsExtractorType, | ||
| }, | ||
| }, nil | ||
| } | ||
|
|
||
| // TypedName returns the type and name of the ModelExtractor. | ||
| func (me *ModelExtractor) TypedName() fwkplugin.TypedName { | ||
| return me.typedName | ||
| } | ||
|
|
||
| // ExpectedInputType defines the type expected by ModelExtractor. | ||
| func (me *ModelExtractor) ExpectedInputType() reflect.Type { | ||
| return ModelsResponseType | ||
| } | ||
|
|
||
| // Extract transforms the data source output into a concrete attribute that | ||
| // is stored on the given endpoint. | ||
| func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error { | ||
| models, ok := data.(*ModelResponse) | ||
| if !ok { | ||
| return fmt.Errorf("unexpected input in Extract: %T", data) | ||
| } | ||
|
|
||
| ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data)) | ||
| return nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| package models | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
|
|
||
| "github.com/google/go-cmp/cmp" | ||
|
|
||
| fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" | ||
| ) | ||
|
|
||
| func TestExtractorExtract(t *testing.T) { | ||
| ctx := context.Background() | ||
|
|
||
| extractor, err := NewModelExtractor() | ||
| if err != nil { | ||
| t.Fatalf("failed to create extractor: %v", err) | ||
| } | ||
|
|
||
| if exType := extractor.TypedName().Type; exType == "" { | ||
| t.Error("empty extractor type") | ||
| } | ||
|
|
||
| if exName := extractor.TypedName().Name; exName == "" { | ||
| t.Error("empty extractor name") | ||
| } | ||
|
|
||
| if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType { | ||
elevran marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| t.Errorf("incorrect expected input type: %v", inputType) | ||
| } | ||
|
|
||
| ep := fwkdl.NewEndpoint(nil, nil) | ||
| if ep == nil { | ||
| t.Fatal("expected non-nil endpoint") | ||
| } | ||
|
|
||
| model := "food-review" | ||
|
|
||
| tests := []struct { | ||
| name string | ||
| data any | ||
| wantErr bool | ||
| updated bool // whether metrics are expected to change | ||
| }{ | ||
| { | ||
| name: "nil data", | ||
| data: nil, | ||
| wantErr: true, | ||
| updated: false, | ||
| }, | ||
| { | ||
| name: "empty ModelsResponse", | ||
| data: &ModelResponse{}, | ||
| wantErr: false, | ||
| updated: false, | ||
| }, | ||
| { | ||
| name: "valid models response", | ||
| data: &ModelResponse{ | ||
| Object: "list", | ||
| Data: []ModelInfo{ | ||
| { | ||
| ID: model, | ||
| }, | ||
| { | ||
| ID: "lora1", | ||
| Parent: model, | ||
| }, | ||
| }, | ||
| }, | ||
| wantErr: false, | ||
| updated: true, | ||
| }, | ||
| } | ||
|
|
||
| for _, tt := range tests { | ||
| t.Run(tt.name, func(t *testing.T) { | ||
| defer func() { | ||
| if r := recover(); r != nil { | ||
| t.Errorf("Extract panicked: %v", r) | ||
| } | ||
| }() | ||
|
|
||
| attr := ep.GetAttributes() | ||
| before, ok := attr.Get(modelsAttributeKey) | ||
| if ok && before != nil { | ||
| t.Error("expected empty attributes") | ||
| } | ||
| err := extractor.Extract(ctx, tt.data, ep) | ||
| after, ok := attr.Get(modelsAttributeKey) | ||
| if !ok && tt.updated { | ||
| t.Error("expected updated attributes") | ||
| } | ||
|
|
||
| if tt.wantErr && err == nil { | ||
| t.Errorf("expected error but got nil") | ||
| } | ||
| if !tt.wantErr && err != nil { | ||
| t.Errorf("unexpected error: %v", err) | ||
| } | ||
|
|
||
| if tt.updated { | ||
| if diff := cmp.Diff(before, after); diff == "" { | ||
| t.Errorf("expected models to be updated, but no change detected") | ||
| } | ||
| } else { | ||
| if diff := cmp.Diff(before, after); diff != "" { | ||
| t.Errorf("expected no models update, but got changes:\n%s", diff) | ||
| } | ||
| } | ||
| }) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| package models | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "fmt" | ||
| "io" | ||
|
|
||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http" | ||
| "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" | ||
| ) | ||
|
|
||
// Plugin type names used by this package.
const (
	// ModelsDataSourceType is the plugin type name of the models data source.
	ModelsDataSourceType = "models-data-source"
	// ModelsExtractorType is the plugin type name of the models extractor.
	ModelsExtractorType = "model-server-protocol-models"
)
|
|
||
// modelsDatasourceParams holds the configuration parameters of the models
// data source, as decoded from the plugin's JSON parameters.
type modelsDatasourceParams struct {
	// Scheme is the protocol scheme used to retrieve models (e.g., "http").
	Scheme string `json:"scheme"`
	// Path is the URL path used to retrieve models (e.g., "/v1/models").
	Path string `json:"path"`
	// InsecureSkipVerify, when true, requests that verification of the model
	// server's certificate be skipped.
	InsecureSkipVerify bool `json:"insecureSkipVerify"`
}
|
|
||
| // ModelDataSourceFactory is a factory function used to instantiate data layer's | ||
| // models data source plugins specified in a configuration. | ||
| func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { | ||
| cfg := defaultDataSourceConfigParams() | ||
| if parameters != nil { // overlay the defaults with configured values | ||
| if err := json.Unmarshal(parameters, cfg); err != nil { | ||
| return nil, err | ||
| } | ||
| } | ||
|
|
||
| ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q; does
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, there is only a check if it's https
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we use the scheme passed in by the user it should at least sanitize it to ensure it's one one of a known set of acceptable values (e.g., "http" and "https").
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks. |
||
| name, parseModels, ModelsResponseType) | ||
| return ds, nil | ||
| } | ||
|
|
||
| // ModelServerExtractorFactory is a factory function used to instantiate data layer's models | ||
| // Extractor plugins specified in a configuration. | ||
| func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { | ||
| extractor, err := NewModelExtractor() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| extractor.typedName.Name = name | ||
elevran marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return extractor, nil | ||
| } | ||
|
|
||
| func defaultDataSourceConfigParams() *modelsDatasourceParams { | ||
| return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true} | ||
| } | ||
|
|
||
| func parseModels(data io.Reader) (any, error) { | ||
| body, err := io.ReadAll(data) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to read response body: %v", err) | ||
| } | ||
| var modelsResponse ModelResponse | ||
| err = json.Unmarshal(body, &modelsResponse) | ||
| return &modelsResponse, err | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.