Commit c7ca20a

update
Signed-off-by: yxia216 <[email protected]>
1 parent 57c42ea commit c7ca20a

File tree

19 files changed, +1432 -153 lines changed
internal/apischema/gcp/gcp.go

Lines changed: 7 additions & 5 deletions
@@ -38,11 +38,13 @@ type GenerateContentRequest struct {
 }
 
 type EmbedContentRequest struct {
-    // Contains the multipart content of a message.
-    //
-    // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
-    Contents []genai.Content `json:"contents"`
-    // Tool details of a tool that the model may use to generate a response.
+    // Content to be embedded. Only text content is supported for embeddings.
+    Content *genai.Content `json:"content"`
 
+    // Optional configuration for the embedding request.
+    // Uses the official genai library configuration structure.
     Config *genai.EmbedContentConfig `json:"config,omitempty"`
 }
+
+// Note: We now use genai.EmbedContentResponse directly instead of defining our own.
+// This provides better compatibility and includes metadata like token usage.
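
For orientation, a minimal sketch of how a request in the new shape might be built and serialized. The struct is redeclared locally so the snippet stands alone, the TaskType value is only an illustrative example, and the field names on genai.Content, genai.Part, and genai.EmbedContentConfig reflect the google.golang.org/genai SDK as I understand it — treat them as assumptions, not part of this commit.

package main

import (
    "encoding/json"
    "fmt"

    "google.golang.org/genai"
)

// Local copy of the EmbedContentRequest shape introduced above, so the sketch
// compiles on its own; the real type lives in internal/apischema/gcp.
type EmbedContentRequest struct {
    Content *genai.Content            `json:"content"`
    Config  *genai.EmbedContentConfig `json:"config,omitempty"`
}

func main() {
    req := EmbedContentRequest{
        // Only text content is supported for embeddings.
        Content: &genai.Content{Parts: []*genai.Part{{Text: "hello world"}}},
        // Assumed example task type; see the Vertex AI task-type docs for valid values.
        Config: &genai.EmbedContentConfig{TaskType: "RETRIEVAL_QUERY"},
    }
    body, _ := json.Marshal(req)
    fmt.Println(string(body))
}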

internal/apischema/openai/openai.go

Lines changed: 59 additions & 8 deletions
@@ -1523,6 +1523,9 @@ type EmbeddingCompletionRequest struct {
     // User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
     // Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
     User *string `json:"user,omitempty"`
+
+    // GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
+    *GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
 }
 
 // GetModel implements ModelName interface
@@ -1553,26 +1556,67 @@ type EmbeddingChatRequest struct {
     // User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
     // Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
     User *string `json:"user,omitempty"`
+
+    // GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
+    *GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
 }
 
 // GetModel implements ModelProvider interface
 func (e *EmbeddingChatRequest) GetModel() string {
     return e.Model
 }
 
-type EmbeddingRequest interface {
-    EmbeddingCompletionRequest | EmbeddingChatRequest
+// EmbeddingRequest is a union type that can handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
+type EmbeddingRequest struct {
+    OfCompletion *EmbeddingCompletionRequest `json:",omitzero,inline"`
+    OfChat *EmbeddingChatRequest `json:",omitzero,inline"`
+}
+
+// UnmarshalJSON implements json.Unmarshaler to handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
+func (e *EmbeddingRequest) UnmarshalJSON(data []byte) error {
+    // Check for Messages field to distinguish EmbeddingChatRequest
+    messagesResult := gjson.GetBytes(data, "messages")
+    if messagesResult.Exists() {
+        var chatReq EmbeddingChatRequest
+        if err := json.Unmarshal(data, &chatReq); err != nil {
+            return err
+        }
+        e.OfChat = &chatReq
+        return nil
+    }
+
+    // Check for Input field to distinguish EmbeddingCompletionRequest
+    inputResult := gjson.GetBytes(data, "input")
+    if inputResult.Exists() {
+        var completionReq EmbeddingCompletionRequest
+        if err := json.Unmarshal(data, &completionReq); err != nil {
+            return err
+        }
+        e.OfCompletion = &completionReq
+        return nil
+    }
+
+    return errors.New("embedding request must have either 'input' field (EmbeddingCompletionRequest) or 'messages' field (EmbeddingChatRequest)")
 }
 
-// ModelName interface for types that can provide a model name
-type ModelName interface {
-    GetModel() string
+// MarshalJSON implements json.Marshaler.
+func (e EmbeddingRequest) MarshalJSON() ([]byte, error) {
+    if e.OfCompletion != nil {
+        return json.Marshal(e.OfCompletion)
+    }
+    if e.OfChat != nil {
+        return json.Marshal(e.OfChat)
+    }
+    return nil, errors.New("no embedding request to marshal")
 }
 
 // GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type
-func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
-    if mp, ok := any(*req).(ModelName); ok {
-        return mp.GetModel()
+func GetModelFromEmbeddingRequest(req *EmbeddingRequest) string {
+    if req.OfCompletion != nil {
+        return req.OfCompletion.GetModel()
+    }
+    if req.OfChat != nil {
+        return req.OfChat.GetModel()
     }
     return ""
 }
@@ -1650,6 +1694,13 @@ type EmbeddingUsage struct {
     TotalTokens int `json:"total_tokens"` //nolint:tagliatelle //follow openai api
 }
 
+// GCPVertexAIEmbeddingVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields for embedding requests.
+type GCPVertexAIEmbeddingVendorFields struct {
+    // Type of task for which the embedding will be used.
+    // https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types
+    TaskType string `json:"task_type,omitempty"`
+}
+
 // JSONUNIXTime is a helper type to marshal/unmarshal time.Time UNIX timestamps.
 type JSONUNIXTime time.Time
 
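
To make the dispatch rule concrete, here is a small self-contained sketch of the same union pattern using trimmed stand-in types (the real Input field is itself a union, reduced to a plain string here; gjson is github.com/tidwall/gjson, which the production code also imports). It is not the production code above, only an illustration of the "messages vs. input" branching.

package main

import (
    "encoding/json"
    "errors"
    "fmt"

    "github.com/tidwall/gjson"
)

// Trimmed stand-ins for the two request shapes.
type completionReq struct {
    Model string `json:"model"`
    Input string `json:"input"`
}

type chatReq struct {
    Model    string           `json:"model"`
    Messages []map[string]any `json:"messages"`
}

// embeddingReq mirrors the EmbeddingRequest union: exactly one branch is set.
type embeddingReq struct {
    OfCompletion *completionReq
    OfChat       *chatReq
}

func (e *embeddingReq) UnmarshalJSON(data []byte) error {
    // "messages" present -> chat-style request.
    if gjson.GetBytes(data, "messages").Exists() {
        var c chatReq
        if err := json.Unmarshal(data, &c); err != nil {
            return err
        }
        e.OfChat = &c
        return nil
    }
    // "input" present -> completion-style request.
    if gjson.GetBytes(data, "input").Exists() {
        var c completionReq
        if err := json.Unmarshal(data, &c); err != nil {
            return err
        }
        e.OfCompletion = &c
        return nil
    }
    return errors.New("expected either 'input' or 'messages'")
}

func main() {
    var req embeddingReq
    if err := json.Unmarshal([]byte(`{"model":"text-embedding-3-small","input":"hello"}`), &req); err != nil {
        panic(err)
    }
    fmt.Println(req.OfCompletion != nil, req.OfChat != nil) // true false
}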

internal/tracing/openinference/openai/request_attrs.go

Lines changed: 41 additions & 13 deletions
@@ -231,11 +231,29 @@ func buildEmbeddingsRequestAttributes(embRequest *openai.EmbeddingRequest, body
     }
 
     if !config.HideLLMInvocationParameters {
+        // Extract parameters from the union type
+        var model string
+        var encodingFormat *string
+        var dimensions *int
+        var user *string
+
+        if embRequest.OfCompletion != nil {
+            model = embRequest.OfCompletion.Model
+            encodingFormat = embRequest.OfCompletion.EncodingFormat
+            dimensions = embRequest.OfCompletion.Dimensions
+            user = embRequest.OfCompletion.User
+        } else if embRequest.OfChat != nil {
+            model = embRequest.OfChat.Model
+            encodingFormat = embRequest.OfChat.EncodingFormat
+            dimensions = embRequest.OfChat.Dimensions
+            user = embRequest.OfChat.User
+        }
+
         params := embeddingsInvocationParameters{
-            Model: embRequest.Model,
-            EncodingFormat: embRequest.EncodingFormat,
-            Dimensions: embRequest.Dimensions,
-            User: embRequest.User,
+            Model: model,
+            EncodingFormat: encodingFormat,
+            Dimensions: dimensions,
+            User: user,
         }
         if invocationParamsJSON, err := json.Marshal(params); err == nil {
             attrs = append(attrs, attribute.String(openinference.EmbeddingInvocationParameters, string(invocationParamsJSON)))
@@ -250,16 +268,26 @@ func buildEmbeddingsRequestAttributes(embRequest *openai.EmbeddingRequest, body
     // 4. Azure deployments don't affect this (they only host OpenAI models with cl100k_base)
     // Following OpenInference spec guidance to only record human-readable text.
     if !config.HideInputs && !config.HideEmbeddingsText {
-        switch input := embRequest.Input.Value.(type) {
-        case string:
-            attrs = append(attrs, attribute.String(openinference.EmbeddingTextAttribute(0), input))
-        case []string:
-            for i, text := range input {
-                attrs = append(attrs, attribute.String(openinference.EmbeddingTextAttribute(i), text))
+        var inputValue any
+        if embRequest.OfCompletion != nil {
+            inputValue = embRequest.OfCompletion.Input.Value
+        } else if embRequest.OfChat != nil {
+            // For chat requests, we'll extract text from messages
+            inputValue = "chat_messages" // Simplified - could be enhanced to extract actual text
+        }
+
+        if inputValue != nil {
+            switch input := inputValue.(type) {
+            case string:
+                attrs = append(attrs, attribute.String(openinference.EmbeddingTextAttribute(0), input))
+            case []string:
+                for i, text := range input {
+                    attrs = append(attrs, attribute.String(openinference.EmbeddingTextAttribute(i), text))
+                }
+            // Token inputs are not recorded to reduce span size.
+            case []int64:
+            case [][]int64:
             }
-            // Token inputs are not recorded to reduce span size.
-        case []int64:
-        case [][]int64:
         }
     }
 
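
As a rough illustration of the tracing side, this standalone sketch marshals the extracted parameters to JSON and attaches them as a string span attribute, the same pattern as above. The parameter struct, its JSON tags, and the literal attribute key are stand-ins for the repo-internal embeddingsInvocationParameters and openinference.EmbeddingInvocationParameters.

package main

import (
    "encoding/json"
    "fmt"

    "go.opentelemetry.io/otel/attribute"
)

// Stand-in for embeddingsInvocationParameters; the JSON tags are assumptions.
type invocationParams struct {
    Model          string  `json:"model"`
    EncodingFormat *string `json:"encoding_format,omitempty"`
    Dimensions     *int    `json:"dimensions,omitempty"`
    User           *string `json:"user,omitempty"`
}

func main() {
    dims := 256
    params := invocationParams{Model: "text-embedding-3-small", Dimensions: &dims}

    var attrs []attribute.KeyValue
    if b, err := json.Marshal(params); err == nil {
        // Stand-in key; the production code uses the openinference.EmbeddingInvocationParameters constant.
        attrs = append(attrs, attribute.String("embedding.invocation_parameters", string(b)))
    }
    fmt.Println(attrs)
}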

internal/translator/openai_azureopenai_embeddings.go

Lines changed: 2 additions & 2 deletions
@@ -33,10 +33,10 @@ type openAIToAzureOpenAITranslatorV1Embedding struct {
 }
 
 // RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
-func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
+func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
     newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
-    modelName := req.Model
+    modelName := openai.GetModelFromEmbeddingRequest(req)
     if o.modelNameOverride != "" {
         // If modelName is set we override the model to be used for the request.
         newBody, err = sjson.SetBytesOptions(original, "model", o.modelNameOverride, sjsonOptions)

internal/translator/openai_embeddings.go

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ type openAIToOpenAITranslatorV1Embedding struct {
 }
 
 // RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
-func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingCompletionRequest, onRetry bool) (
+func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
     newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
     if o.modelNameOverride != "" {

internal/translator/openai_gcpvertexai.go

Lines changed: 51 additions & 17 deletions
@@ -533,10 +533,9 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiResponseToOpenAIMe
     return openaiResp, nil
 }
 
-// ResponseError implements [OpenAIChatCompletionTranslator.ResponseError].
-// Translate GCP Vertex AI exceptions to OpenAI error type.
-// GCP error responses typically contain JSON with error details or plain text error messages.
-func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) (
+// convertGCPVertexAIErrorToOpenAI converts GCP Vertex AI error responses to OpenAI error format.
+// This is a shared function used by both chat completion and embedding translators.
+func convertGCPVertexAIErrorToOpenAI(respHeaders map[string]string, body io.Reader) (
     newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
     var buf []byte
@@ -545,8 +544,8 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseError(respHeader
         return nil, nil, fmt.Errorf("failed to read error body: %w", err)
     }
 
-    // Assume all responses have a valid status code header.
     statusCode := respHeaders[statusHeaderName]
+    contentType := respHeaders[contentTypeHeaderName]
 
     openaiError := openai.Error{
         Type: "error",
@@ -556,19 +555,45 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseError(respHeader
         },
     }
 
-    var gcpError gcpVertexAIError
-    // Try to parse as GCP error response structure.
-    if err = json.Unmarshal(buf, &gcpError); err == nil {
-        errMsg := gcpError.Error.Message
-        if len(gcpError.Error.Details) > 0 {
-            // If details are present and not null, append them to the error message.
-            errMsg = fmt.Sprintf("Error: %s\nDetails: %s", errMsg, string(gcpError.Error.Details))
-        }
-        openaiError.Error.Type = gcpError.Error.Status
-        openaiError.Error.Message = errMsg
-    } else {
-        // If not JSON, read the raw body as the error message.
+    // If the content type is not JSON, treat it as a generic error
+    if contentType != "" && contentType != jsonContentType {
         openaiError.Error.Message = string(buf)
+    } else {
+        var gcpError gcpVertexAIError
+        // Try to parse as GCP error response structure first
+        if err = json.Unmarshal(buf, &gcpError); err == nil {
+            errMsg := gcpError.Error.Message
+            if len(gcpError.Error.Details) > 0 {
+                // If details are present and not null, append them to the error message.
+                errMsg = fmt.Sprintf("Error: %s\nDetails: %s", errMsg, string(gcpError.Error.Details))
+            }
+            openaiError.Error.Type = gcpError.Error.Status
+            openaiError.Error.Message = errMsg
+        } else {
+            // Try to parse as generic JSON error format
+            var genericError map[string]interface{}
+            if err := json.Unmarshal(buf, &genericError); err == nil {
+                // Extract error message from generic JSON error format
+                var errorMessage string
+                if errorField, exists := genericError["error"]; exists {
+                    if errorMap, ok := errorField.(map[string]interface{}); ok {
+                        if message, exists := errorMap["message"]; exists {
+                            if msgStr, ok := message.(string); ok {
+                                errorMessage = msgStr
+                            }
+                        }
+                    }
+                }
+                if errorMessage != "" {
+                    openaiError.Error.Message = errorMessage
+                } else {
+                    openaiError.Error.Message = string(buf)
+                }
+            } else {
+                // If not parseable as JSON, use raw body as the error message
+                openaiError.Error.Message = string(buf)
+            }
+        }
     }
 
     newBody, err = json.Marshal(openaiError)
@@ -581,3 +606,12 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseError(respHeader
     }
     return
 }
+
+// ResponseError implements [OpenAIChatCompletionTranslator.ResponseError].
+// Translate GCP Vertex AI exceptions to OpenAI error type.
+// GCP error responses typically contain JSON with error details or plain text error messages.
+func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) (
+    newHeaders []internalapi.Header, newBody []byte, err error,
+) {
+    return convertGCPVertexAIErrorToOpenAI(respHeaders, body)
+}
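
To make the fallback chain concrete, here is a small self-contained sketch that reproduces it on trimmed stand-in types: GCP error shape first, then a generic {"error":{"message":...}} object, then the raw body. The error structs and the sample 404 body below are illustrative, not the repo's gcpVertexAIError and openai.Error definitions.

package main

import (
    "encoding/json"
    "fmt"
)

// Illustrative stand-ins for the translator's error shapes.
type gcpErr struct {
    Error struct {
        Code    int             `json:"code"`
        Message string          `json:"message"`
        Status  string          `json:"status"`
        Details json.RawMessage `json:"details"`
    } `json:"error"`
}

type openAIErr struct {
    Type  string `json:"type"`
    Error struct {
        Type    string `json:"type"`
        Message string `json:"message"`
    } `json:"error"`
}

// convert mirrors the fallback chain; the Status check is a sketch-level guard,
// since unmarshalling into a permissive struct rarely fails outright.
func convert(body []byte) openAIErr {
    out := openAIErr{Type: "error"}
    var g gcpErr
    if err := json.Unmarshal(body, &g); err == nil && g.Error.Status != "" {
        out.Error.Type = g.Error.Status
        out.Error.Message = g.Error.Message
        if len(g.Error.Details) > 0 {
            out.Error.Message = fmt.Sprintf("Error: %s\nDetails: %s", g.Error.Message, g.Error.Details)
        }
        return out
    }
    var generic map[string]interface{}
    if err := json.Unmarshal(body, &generic); err == nil {
        if e, ok := generic["error"].(map[string]interface{}); ok {
            if msg, ok := e["message"].(string); ok && msg != "" {
                out.Error.Message = msg
                return out
            }
        }
    }
    // Last resort: raw body as the error message.
    out.Error.Message = string(body)
    return out
}

func main() {
    b, _ := json.Marshal(convert([]byte(`{"error":{"code":404,"message":"Publisher Model not found","status":"NOT_FOUND"}}`)))
    fmt.Println(string(b))
}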
