Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion aigateway/handler/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,14 +407,14 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
slog.InfoContext(c, "proxy embedding request to model endpoint", slog.Any("target", target), slog.Any("host", host),
slog.Any("user", username), slog.Any("model_id", modelID))
rp, _ := proxy.NewReverseProxy(target)
w := NewResponseWriterWrapperEmbedding(c.Writer)

tokenCounter := h.tokenCounterFactory.NewEmbedding(token.CreateParam{
Endpoint: target,
Host: host,
Model: modelName,
ImageID: model.ImageID,
})
w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter)
tokenCounter.Input(req.Input)

rp.ServeHTTP(w, c.Request, "", host)
Expand Down
26 changes: 3 additions & 23 deletions aigateway/handler/response_writer_wrapper_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,20 @@ import (

"github.com/openai/openai-go/v3"
"opencsg.com/csghub-server/aigateway/token"
"opencsg.com/csghub-server/builder/rpc"
)

// ResponseWriterWrapperEmbedding wraps an http.ResponseWriter for embedding
// responses and holds the collaborators attached to it.
// NOTE(review): tokenCounter is declared twice below (pointer and non-pointer
// form); this looks like both sides of a diff rendered together — confirm
// which declaration is current.
type ResponseWriterWrapperEmbedding struct {
// internalWritter is the wrapped writer that Header and Write delegate to.
internalWritter http.ResponseWriter
// modSvcClient is the moderation service client installed via WithModeration.
modSvcClient rpc.ModerationSvcClient
// tokenCounter is the embedding token counter attached to this wrapper.
tokenCounter *token.EmbeddingTokenCounter
tokenCounter token.EmbeddingTokenCounter
}

// NewResponseWriterWrapperEmbedding builds a ResponseWriterWrapperEmbedding
// around internalWritter and stores tokenCounter on it.
// NOTE(review): two signatures appear below (with and without tokenCounter);
// this looks like both sides of a diff rendered together — confirm which one
// is current.
func NewResponseWriterWrapperEmbedding(internalWritter http.ResponseWriter) *ResponseWriterWrapperEmbedding {
func NewResponseWriterWrapperEmbedding(internalWritter http.ResponseWriter, tokenCounter token.EmbeddingTokenCounter) *ResponseWriterWrapperEmbedding {
return &ResponseWriterWrapperEmbedding{
internalWritter: internalWritter,
tokenCounter: tokenCounter,
}
}

// WithModeration stores modSvcClient on the wrapper, replacing any client set
// previously. It returns nothing, so it cannot be chained.
func (rw *ResponseWriterWrapperEmbedding) WithModeration(modSvcClient rpc.ModerationSvcClient) {
rw.modSvcClient = modSvcClient
}

// WithTokenCounter stores counter on the wrapper, replacing any counter set
// previously. It returns nothing, so it cannot be chained.
func (rw *ResponseWriterWrapperEmbedding) WithTokenCounter(counter *token.EmbeddingTokenCounter) {
rw.tokenCounter = counter
}

// Header returns the header map of the wrapped http.ResponseWriter, so the
// wrapper can stand in wherever an http.ResponseWriter is expected.
func (rw *ResponseWriterWrapperEmbedding) Header() http.Header {
	headers := rw.internalWritter.Header()
	return headers
}
Expand Down Expand Up @@ -66,14 +57,3 @@ func (rw *ResponseWriterWrapperEmbedding) Write(data []byte) (int, error) {

return rw.internalWritter.Write(data)
}

//TODO: moderate embedding request and generate sensitive response if needed
// func (rw *ResponseWriterWrapperEmbedding) generateSensitiveResp(originResp openai.CreateEmbeddingResponse) openai.CreateEmbeddingResponse {
// newResp := openai.CreateEmbeddingResponse{
// Data: nil,
// Model: originResp.Model,
// Object: originResp.Object,
// Usage: originResp.Usage,
// }
// return newResp
// }
87 changes: 83 additions & 4 deletions aigateway/handler/response_writer_wrapper_embedding_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package handler

import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"net/http/httptest"
"testing"

"github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
mocktoken "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/token"
"opencsg.com/csghub-server/aigateway/token"
)

Expand Down Expand Up @@ -71,20 +74,35 @@ func TestResponseWriterWrapperEmbedding_Write(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
// Create a test http.ResponseWriter
w := httptest.NewRecorder()
wrapper := NewResponseWriterWrapperEmbedding(w)

// Setup token counter if needed
var wrapper *ResponseWriterWrapperEmbedding
var counter *mocktoken.MockEmbeddingTokenCounter = nil
if tt.withCounter {
counter := token.NewEmbeddingTokenCounter(nil)
wrapper.WithTokenCounter(counter)
counter = mocktoken.NewMockEmbeddingTokenCounter(t)
wrapper = NewResponseWriterWrapperEmbedding(w, counter)
} else {
wrapper = NewResponseWriterWrapperEmbedding(w, nil)
}

// Prepare test data
var data []byte
var err error
if tt.name == "invalid json data" {
data = []byte(`invalid json`)
} else {
} else if tt.withCounter {
data, err = json.Marshal(tt.response)
assert.NoError(t, err)
var expectResp openai.CreateEmbeddingResponse
err = json.Unmarshal(data, &expectResp)
assert.NoError(t, err)
counter.EXPECT().Embedding(expectResp.Usage).Return().Once()
counter.EXPECT().Usage(context.Background()).Return(&token.Usage{
TotalTokens: expectResp.Usage.TotalTokens,
PromptTokens: expectResp.Usage.PromptTokens,
CompletionTokens: 0,
}, nil).Once()
} else if !tt.withCounter {
data, err = json.Marshal(tt.response)
assert.NoError(t, err)
}
Expand Down Expand Up @@ -112,3 +130,64 @@ func TestResponseWriterWrapperEmbedding_Write(t *testing.T) {
})
}
}

// TestResponseWriterWrapperEmbedding_Write_Gzip passes a gzip-compressed
// embedding response through the wrapper. It checks that the compressed bytes
// reach the underlying writer unchanged and registers an expectation that the
// token counter's Embedding method is called with the response's usage section.
func TestResponseWriterWrapperEmbedding_Write_Gzip(t *testing.T) {
	// Create a test http.ResponseWriter
	w := httptest.NewRecorder()

	// Setup token counter
	counter := mocktoken.NewMockEmbeddingTokenCounter(t)
	wrapper := NewResponseWriterWrapperEmbedding(w, counter)

	// Create embedding response
	response := openai.CreateEmbeddingResponse{
		Object: "embedding",
		Data: []openai.Embedding{
			{
				Object:    "embedding",
				Embedding: []float64{0.1, 0.2, 0.3},
				Index:     0,
			},
		},
		Model: "text-embedding-ada-002",
		Usage: openai.CreateEmbeddingResponseUsage{
			PromptTokens: 10,
			TotalTokens:  10,
		},
	}

	// Marshal response to JSON
	jsonData, err := json.Marshal(response)
	assert.NoError(t, err)

	// Create gzip compressed data
	var gzippedData bytes.Buffer
	gzipWriter := gzip.NewWriter(&gzippedData)
	_, err = gzipWriter.Write(jsonData)
	assert.NoError(t, err)
	// Close flushes pending compressed data and writes the gzip footer; if it
	// failed, the payload below would be a truncated stream.
	assert.NoError(t, gzipWriter.Close())
	payload := gzippedData.Bytes()

	// Round-trip the JSON so the expected usage has the same shape as a value
	// decoded from the wire.
	var expectResp openai.CreateEmbeddingResponse
	err = json.Unmarshal(jsonData, &expectResp)
	assert.NoError(t, err)

	// Set expectations for token counter
	counter.EXPECT().Embedding(expectResp.Usage).Return().Once()
	counter.EXPECT().Usage(context.Background()).Return(&token.Usage{
		TotalTokens:      expectResp.Usage.TotalTokens,
		PromptTokens:     expectResp.Usage.PromptTokens,
		CompletionTokens: 0,
	}, nil).Once()

	// Execute Write method with gzipped data
	n, err := wrapper.Write(payload)

	// Verify results: the compressed bytes must be forwarded untouched.
	assert.NoError(t, err)
	assert.Equal(t, len(payload), n)
	assert.Equal(t, payload, w.Body.Bytes())

	// Read the usage back through the wrapper's counter.
	// NOTE(review): the counter is a mock, so this call is answered by the
	// Usage stub registered above; it checks the stubbed values rather than
	// anything computed by Write — confirm that is the intent.
	usage, err := wrapper.tokenCounter.Usage(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
	assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
}
Loading
Loading