From 81c3486e40f339ba99dc61b2841f3ff5663646fc Mon Sep 17 00:00:00 2001 From: Dev Agent Date: Thu, 25 Dec 2025 09:33:20 +0000 Subject: [PATCH] Feat aigateway record usage support filter --- aigateway/component/openai.go | 108 +++++- aigateway/component/openai_test.go | 523 ++++++++++++++++++++++++++++- aigateway/handler/openai.go | 5 +- aigateway/types/openai.go | 1 + common/types/accounting.go | 10 + 5 files changed, 633 insertions(+), 14 deletions(-) diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index 71a6bd9ee..903886a06 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -33,6 +33,7 @@ type OpenAIComponent interface { type openaiComponentImpl struct { userStore database.UserStore + organStore database.OrgStore deployStore database.DeployTaskStore eventPub *event.EventPublisher extllmStore database.LLMConfigStore @@ -51,6 +52,7 @@ func NewOpenAIComponentFromConfig(config *config.Config) (OpenAIComponent, error } return &openaiComponentImpl{ userStore: database.NewUserStore(), + organStore: database.NewOrgStore(), deployStore: database.NewDeployTaskStore(), eventPub: &event.DefaultEventPublisher, extllmStore: database.NewLLMConfigStore(config), @@ -108,6 +110,7 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userName string }, InternalModelInfo: types.InternalModelInfo{ CSGHubModelID: deploy.Repository.Path, + OwnerUUID: deploy.User.UUID, ClusterID: deploy.ClusterID, SvcName: deploy.SvcName, SvcType: deploy.Type, @@ -284,23 +287,67 @@ func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, mo var tokenUsageExtra = struct { PromptTokenNum string `json:"prompt_token_num"` CompletionTokenNum string `json:"completion_token_num"` + // 0: external, 1: owner is user, 2: other user is inference, 3: serverless + OwnerType commontypes.TokenUsageType }{ PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens), CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens), } + if model.CSGHubModelID != "" && model.Provider != "" { + slog.WarnContext(c, "bad model info, both csghub model id and external model provider is set", + slog.Any("model info", model)) + } + if model.CSGHubModelID == "" && model.Provider == "" { + slog.WarnContext(c, "bad model info, both csghub model id and external model provider is not set", + slog.Any("model info", model)) + } + if model.CSGHubModelID != "" { + switch model.SvcType { + case commontypes.ServerlessType: + tokenUsageExtra.OwnerType = commontypes.CSGHubServerlessInference + case commontypes.InferenceType: + if model.OwnerUUID == userUUID { + tokenUsageExtra.OwnerType = commontypes.CSGHubUserDeployedInference + } else { + belong, err := m.checkOrganization(c, userUUID, model.OwnerUUID) + if err != nil { + return fmt.Errorf("failed to check organization,error:%w", err) + } + if belong { + tokenUsageExtra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference + } else { + tokenUsageExtra.OwnerType = commontypes.CSGHubOtherDeployedInference + } + } + default: + slog.WarnContext(c, "bad model info, csghub model missing service type", + slog.Any("model info", model)) + } + } + if model.Provider != "" { + tokenUsageExtra.OwnerType = commontypes.ExternalInference + } + extraData, _ := json.Marshal(tokenUsageExtra) event := commontypes.MeteringEvent{ - Uuid: uuid.New(), - UserUUID: userUUID, - Value: usage.TotalTokens, - ValueType: commontypes.TokenNumberType, // count by token - Scene: getSceneFromSvcType(model.SvcType), - OpUID: "", - ResourceID: model.CSGHubModelID, - ResourceName: model.CSGHubModelID, - CustomerID: model.SvcName, - CreatedAt: time.Now(), - Extra: string(extraData), + Uuid: uuid.New(), + UserUUID: userUUID, + Value: usage.TotalTokens, + ValueType: commontypes.TokenNumberType, // count by token + Scene: int(commontypes.SceneModelServerless), + OpUID: "aigateway", + CreatedAt: time.Now(), + Extra: string(extraData), + } + if model.CSGHubModelID != "" { + event.ResourceID = model.CSGHubModelID + event.ResourceName = model.CSGHubModelID + event.CustomerID = model.SvcName + } + if model.Provider != "" { + event.ResourceID = model.ID + event.ResourceName = model.ID + event.CustomerID = model.Provider } eventData, _ := json.Marshal(event) err = m.eventPub.PublishMeteringEvent(eventData) @@ -312,3 +359,42 @@ func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, mo slog.Info("public token usage event success", "event", event) return nil } + +func (m *openaiComponentImpl) checkOrganization(c context.Context, userUUID string, ownerUUID string) (bool, error) { + user, err := m.userStore.FindByUUID(c, userUUID) + if err != nil { + slog.ErrorContext(c, "Failed to find user in db") + return false, err + } + owner, err := m.userStore.FindByUUID(c, ownerUUID) + if err != nil { + slog.ErrorContext(c, "Failed to find owner in db") + return false, err + } + userOrgs, err := m.organStore.GetUserBelongOrgs(c, user.ID) + if err != nil { + slog.ErrorContext(c, "Failed to find user organizations") + return false, err + } + if len(userOrgs) == 0 { + return false, nil + } + ownerOrgs, err := m.organStore.GetUserBelongOrgs(c, owner.ID) + if err != nil { + slog.ErrorContext(c, "Failed to find owner organizations") + return false, err + } + if len(ownerOrgs) == 0 { + return false, nil + } + userOrgansMap := make(map[int64]struct{}, len(userOrgs)) + for _, org := range userOrgs { + userOrgansMap[org.ID] = struct{}{} + } + for _, org := range ownerOrgs { + if _, ok := userOrgansMap[org.ID]; ok { + return true, nil + } + } + return false, nil +} diff --git a/aigateway/component/openai_test.go b/aigateway/component/openai_test.go index 97b0bfeff..16bb58a10 100644 --- a/aigateway/component/openai_test.go +++ b/aigateway/component/openai_test.go @@ -330,6 +330,228 @@ func TestGetSceneFromSvcType(t *testing.T) { } } +func TestOpenAIComponent_checkOrganization(t *testing.T) { + mockUserStore := mockdb.NewMockUserStore(t) + mockOrgStore := mockdb.NewMockOrgStore(t) + + comp := &openaiComponentImpl{ + userStore: mockUserStore, + organStore: mockOrgStore, + } + + t.Run("users belong to same organization - should return true", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "owner-uuid-456" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + owner := &database.User{ + ID: 2, + UUID: ownerUUID, + } + + org1 := database.Organization{ + ID: 100, + Name: "org1", + } + org2 := database.Organization{ + ID: 200, + Name: "org2", + } + + userOrgs := []database.Organization{org1, org2} + ownerOrgs := []database.Organization{org2, {ID: 300, Name: "org3"}} + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(userOrgs, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, owner.ID).Return(ownerOrgs, nil).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.NoError(t, err) + assert.True(t, result, "Users should belong to same organization") + }) + + t.Run("users do not belong to same organization - should return false", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "owner-uuid-456" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + owner := &database.User{ + ID: 2, + UUID: ownerUUID, + } + + userOrgs := []database.Organization{ + {ID: 100, Name: "org1"}, + {ID: 200, Name: "org2"}, + } + ownerOrgs := []database.Organization{ + {ID: 300, Name: "org3"}, + {ID: 400, Name: "org4"}, + } + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(userOrgs, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, owner.ID).Return(ownerOrgs, nil).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.NoError(t, err) + assert.False(t, result, "Users should not belong to same organization") + }) + + t.Run("user has no organizations - should return false", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "owner-uuid-456" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + owner := &database.User{ + ID: 2, + UUID: ownerUUID, + } + + userOrgs := []database.Organization{} + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(userOrgs, nil).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.NoError(t, err) + assert.False(t, result, "User with no organizations should not have access") + }) + + t.Run("owner has no organizations - should return false", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "owner-uuid-456" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + owner := &database.User{ + ID: 2, + UUID: ownerUUID, + } + + userOrgs := []database.Organization{ + {ID: 100, Name: "org1"}, + } + ownerOrgs := []database.Organization{} + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(userOrgs, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, owner.ID).Return(ownerOrgs, nil).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.NoError(t, err) + assert.False(t, result, "Owner with no organizations should not grant access") + }) + + t.Run("user not found - should return false without error", func(t *testing.T) { + ctx := context.Background() + userUUID := "nonexistent-user" + ownerUUID := "owner-uuid-456" + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(&database.User{}, errors.New("user not found")).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.Error(t, err) + assert.False(t, result, "Should return false when user is not found") + }) + + t.Run("owner not found - should return false without error", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "nonexistent-owner" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(&database.User{}, errors.New("owner not found")).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.Error(t, err) + assert.False(t, result, "Should return false when owner is not found") + }) + + t.Run("error getting user organizations - should return false without error", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-123" + ownerUUID := "owner-uuid-456" + + user := &database.User{ + ID: 1, + UUID: userUUID, + } + owner := &database.User{ + ID: 2, + UUID: ownerUUID, + } + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(nil, errors.New("database error")).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.Error(t, err) + assert.False(t, result, "Should return false when there's an error getting user organizations") + }) + + t.Run("error getting owner organizations - should return false without error", func(t *testing.T) { + ctx := context.Background() + userUUID := "user-uuid-666" + ownerUUID := "owner-uuid-777" + + user := &database.User{ + ID: 66, + UUID: userUUID, + } + owner := &database.User{ + ID: 77, + UUID: ownerUUID, + } + + userOrgs := []database.Organization{ + {ID: 100, Name: "org1"}, + } + + mockUserStore.EXPECT().FindByUUID(ctx, userUUID).Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(ctx, ownerUUID).Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, user.ID).Return(userOrgs, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(ctx, owner.ID).Return(nil, errors.New("database error")).Once() + + result, err := comp.checkOrganization(ctx, userUUID, ownerUUID) + + assert.Error(t, err) + assert.False(t, result, "Should return false when there's an error getting owner organizations") + }) +} + func TestOpenAIComponent_ExtGetAvailableModels_Error(t *testing.T) { ctx := context.Background() mockLLMConfigStore := mockdb.NewMockLLMConfigStore(t) @@ -439,6 +661,7 @@ func TestOpenAIComponentImpl_RecordUsage(t *testing.T) { mockUserStore := &mockdb.MockUserStore{} mockDeployStore := &mockdb.MockDeployTaskStore{} + mockOrgStore := &mockdb.MockOrgStore{} var mockCounter *mocktoken.MockCounter var comp *openaiComponentImpl @@ -452,13 +675,89 @@ func TestOpenAIComponentImpl_RecordUsage(t *testing.T) { setupMock func() }{ { - name: "successful record - dedicated inference", + name: "successful record - dedicated inference by other user but not same organ", + userUUID: "test-user-uuid", + model: &types.Model{ + InternalModelInfo: types.InternalModelInfo{ + CSGHubModelID: "test-model", + SvcName: "test-service", + SvcType: commontypes.InferenceType, + OwnerUUID: "another-user-uuid", + }, + }, + usage: &openai.CompletionUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + wantError: false, + setupMock: func() { + + mockMQ := mockmq.NewMockMessageQueue(t) + mockBLDMQ := mockbldmq.NewMockMessageQueue(t) + + eventPub := &event.EventPublisher{ + Connector: mockMQ, + SyncInterval: 1, + MQ: mockBLDMQ, + Cfg: cfg, + } + mockCounter = mocktoken.NewMockCounter(t) + + comp = &openaiComponentImpl{ + userStore: mockUserStore, + deployStore: mockDeployStore, + eventPub: eventPub, + organStore: mockOrgStore, + } + mockCounter.EXPECT().Usage(mock.Anything).Return(&token.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, nil) + user := &database.User{ + ID: 1, + Username: "testuser", + } + owner := &database.User{ + ID: 2, + Username: "owneruser", + } + mockUserStore.EXPECT().FindByUUID(mock.Anything, "test-user-uuid").Return(user, nil).Once() + mockUserStore.EXPECT().FindByUUID(mock.Anything, "another-user-uuid").Return(owner, nil).Once() + mockOrgStore.EXPECT().GetUserBelongOrgs(mock.Anything, user.ID).Return([]database.Organization{}, nil).Once() + mockBLDMQ.EXPECT().Publish(bldmq.MeterDurationSendSubject, mock.Anything).RunAndReturn(func(topic string, data []byte) error { + var evt commontypes.MeteringEvent + err := json.Unmarshal(data, &evt) + require.NoError(t, err) + require.Equal(t, "test-model", evt.ResourceID) + require.Equal(t, "test-model", evt.ResourceName) + require.Equal(t, "test-service", evt.CustomerID) + require.Equal(t, int(commontypes.SceneModelServerless), evt.Scene) + require.Equal(t, "test-user-uuid", evt.UserUUID) + require.Equal(t, commontypes.TokenNumberType, evt.ValueType) + require.Equal(t, int64(150), evt.Value) + var tokenUsageExtra struct { + PromptTokenNum string `json:"prompt_token_num"` + CompletionTokenNum string `json:"completion_token_num"` + } + err = json.Unmarshal([]byte(evt.Extra), &tokenUsageExtra) + require.NoError(t, err) + require.Equal(t, "100", tokenUsageExtra.PromptTokenNum) + require.Equal(t, "50", tokenUsageExtra.CompletionTokenNum) + return nil + }) + }, + }, + { + name: "successful record - dedicated inference deployed by same user", userUUID: "test-user-uuid", model: &types.Model{ InternalModelInfo: types.InternalModelInfo{ CSGHubModelID: "test-model", SvcName: "test-service", SvcType: commontypes.InferenceType, + OwnerUUID: "test-user-uuid", }, }, usage: &openai.CompletionUsage{ @@ -498,7 +797,7 @@ func TestOpenAIComponentImpl_RecordUsage(t *testing.T) { require.Equal(t, "test-model", evt.ResourceID) require.Equal(t, "test-model", evt.ResourceName) require.Equal(t, "test-service", evt.CustomerID) - require.Equal(t, int(commontypes.SceneModelInference), evt.Scene) + require.Equal(t, int(commontypes.SceneModelServerless), evt.Scene) require.Equal(t, "test-user-uuid", evt.UserUUID) require.Equal(t, commontypes.TokenNumberType, evt.ValueType) require.Equal(t, int64(150), evt.Value) @@ -614,6 +913,7 @@ func TestOpenAIComponentImpl_RecordUsage(t *testing.T) { CSGHubModelID: "test-model", SvcName: "test-service", SvcType: commontypes.InferenceType, + OwnerUUID: "test-user-uuid", }, }, usage: &openai.CompletionUsage{ @@ -660,3 +960,222 @@ func TestOpenAIComponentImpl_RecordUsage(t *testing.T) { }) } } + +func TestOpenAIComponentImpl_RecordUsage_ExternalModel(t *testing.T) { + cfg, err := config.LoadConfig() + require.Nil(t, err) + + mockUserStore := &mockdb.MockUserStore{} + mockDeployStore := &mockdb.MockDeployTaskStore{} + + var mockCounter *mocktoken.MockCounter + var comp *openaiComponentImpl + + tests := []struct { + name string + userUUID string + model *types.Model + wantError bool + setupMock func() + }{ + { + name: "successful record - external model with OpenAI provider", + userUUID: "test-user-uuid", + model: &types.Model{ + BaseModel: types.BaseModel{ + ID: "gpt-4", + OwnedBy: "openai", + }, + ExternalModelInfo: types.ExternalModelInfo{ + Provider: "openai", + }, + }, + wantError: false, + setupMock: func() { + mockMQ := mockmq.NewMockMessageQueue(t) + mockBLDMQ := mockbldmq.NewMockMessageQueue(t) + + eventPub := &event.EventPublisher{ + Connector: mockMQ, + SyncInterval: 1, + MQ: mockBLDMQ, + Cfg: cfg, + } + mockCounter = mocktoken.NewMockCounter(t) + + comp = &openaiComponentImpl{ + userStore: mockUserStore, + deployStore: mockDeployStore, + eventPub: eventPub, + } + mockCounter.EXPECT().Usage(mock.Anything).Return(&token.Usage{ + PromptTokens: 200, + CompletionTokens: 100, + TotalTokens: 300, + }, nil) + + mockBLDMQ.EXPECT().Publish(bldmq.MeterDurationSendSubject, mock.Anything).RunAndReturn(func(topic string, data []byte) error { + var evt commontypes.MeteringEvent + err := json.Unmarshal(data, &evt) + require.NoError(t, err) + require.Equal(t, "gpt-4", evt.ResourceID) + require.Equal(t, "gpt-4", evt.ResourceName) + require.Equal(t, "test-user-uuid", evt.UserUUID) + require.Equal(t, commontypes.TokenNumberType, evt.ValueType) + require.Equal(t, int64(300), evt.Value) + require.Equal(t, int(commontypes.SceneModelServerless), evt.Scene) + + var tokenUsageExtra struct { + PromptTokenNum string `json:"prompt_token_num"` + CompletionTokenNum string `json:"completion_token_num"` + OwnerType commontypes.TokenUsageType + } + err = json.Unmarshal([]byte(evt.Extra), &tokenUsageExtra) + require.NoError(t, err) + require.Equal(t, "200", tokenUsageExtra.PromptTokenNum) + require.Equal(t, "100", tokenUsageExtra.CompletionTokenNum) + require.Equal(t, commontypes.ExternalInference, tokenUsageExtra.OwnerType) + return nil + }) + }, + }, + { + name: "counter error for external model", + userUUID: "test-user-uuid", + model: &types.Model{ + BaseModel: types.BaseModel{ + ID: "gpt-3.5-turbo", + OwnedBy: "openai", + }, + ExternalModelInfo: types.ExternalModelInfo{ + Provider: "openai", + }, + }, + wantError: true, + setupMock: func() { + mockMQ := mockmq.NewMockMessageQueue(t) + mockBLDMQ := mockbldmq.NewMockMessageQueue(t) + eventPub := &event.EventPublisher{ + Connector: mockMQ, + SyncInterval: 1, + MQ: mockBLDMQ, + } + mockCounter = mocktoken.NewMockCounter(t) + + comp = &openaiComponentImpl{ + userStore: mockUserStore, + deployStore: mockDeployStore, + eventPub: eventPub, + } + mockCounter.EXPECT().Usage(mock.Anything).Return(nil, errors.New("counter error")) + }, + }, + { + name: "publish error for external model", + userUUID: "test-user-uuid", + model: &types.Model{ + BaseModel: types.BaseModel{ + ID: "gemini-pro", + OwnedBy: "google", + }, + ExternalModelInfo: types.ExternalModelInfo{ + Provider: "google", + }, + }, + wantError: true, + setupMock: func() { + mockMQ := mockmq.NewMockMessageQueue(t) + mockBLDMQ := mockbldmq.NewMockMessageQueue(t) + eventPub := &event.EventPublisher{ + Connector: mockMQ, + SyncInterval: 1, + MQ: mockBLDMQ, + Cfg: cfg, + } + mockCounter = mocktoken.NewMockCounter(t) + comp = &openaiComponentImpl{ + userStore: mockUserStore, + deployStore: mockDeployStore, + eventPub: eventPub, + } + mockCounter.EXPECT().Usage(mock.Anything).Return(&token.Usage{ + PromptTokens: 50, + CompletionTokens: 25, + TotalTokens: 75, + }, nil) + mockBLDMQ.EXPECT().Publish(bldmq.MeterDurationSendSubject, mock.Anything).Return(errors.New("publish error")).Times(3) + }, + }, + { + name: "external model with zero tokens", + userUUID: "test-user-uuid", + model: &types.Model{ + BaseModel: types.BaseModel{ + ID: "test-model", + OwnedBy: "test-provider", + }, + ExternalModelInfo: types.ExternalModelInfo{ + Provider: "test-provider", + }, + }, + wantError: false, + setupMock: func() { + mockMQ := mockmq.NewMockMessageQueue(t) + mockBLDMQ := mockbldmq.NewMockMessageQueue(t) + + eventPub := &event.EventPublisher{ + Connector: mockMQ, + SyncInterval: 1, + MQ: mockBLDMQ, + Cfg: cfg, + } + mockCounter = mocktoken.NewMockCounter(t) + + comp = &openaiComponentImpl{ + userStore: mockUserStore, + deployStore: mockDeployStore, + eventPub: eventPub, + } + mockCounter.EXPECT().Usage(mock.Anything).Return(&token.Usage{ + PromptTokens: 0, + CompletionTokens: 0, + TotalTokens: 0, + }, nil) + + mockBLDMQ.EXPECT().Publish(bldmq.MeterDurationSendSubject, mock.Anything).RunAndReturn(func(topic string, data []byte) error { + var evt commontypes.MeteringEvent + err := json.Unmarshal(data, &evt) + require.NoError(t, err) + require.Equal(t, "test-model", evt.ResourceID) + require.Equal(t, "test-model", evt.ResourceName) + require.Equal(t, int64(0), evt.Value) + + var tokenUsageExtra struct { + PromptTokenNum string `json:"prompt_token_num"` + CompletionTokenNum string `json:"completion_token_num"` + OwnerType commontypes.TokenUsageType + } + err = json.Unmarshal([]byte(evt.Extra), &tokenUsageExtra) + require.NoError(t, err) + require.Equal(t, "0", tokenUsageExtra.PromptTokenNum) + require.Equal(t, "0", tokenUsageExtra.CompletionTokenNum) + require.Equal(t, commontypes.ExternalInference, tokenUsageExtra.OwnerType) + return nil + }) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + + err := comp.RecordUsage(context.Background(), tt.userUUID, tt.model, mockCounter) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index f94b0c769..c54a3aa49 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/gin-gonic/gin" "github.com/openai/openai-go/v3" @@ -320,7 +321,9 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { rp.ServeHTTP(w, c.Request, proxyToApi, host) go func() { - err := h.openaiComponent.RecordUsage(c.Request.Context(), userUUID, model, tokenCounter) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err := h.openaiComponent.RecordUsage(ctx, userUUID, model, tokenCounter) if err != nil { slog.Error("failed to record token usage", "error", err) } diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index 851780e74..f18ad3323 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -19,6 +19,7 @@ type BaseModel struct { // InternalModelInfo represents the internal model fields type InternalModelInfo struct { CSGHubModelID string `json:"-"` // the internal model id (repo path) in CSGHub + OwnerUUID string `json:"-"` // the uuid of deploy owner ClusterID string `json:"-"` // the deployed cluster id in CSGHub SvcName string `json:"-"` // the internal service name in CSGHub SvcType int `json:"-"` // the internal service type like dedicated or serverless in CSGHub diff --git a/common/types/accounting.go b/common/types/accounting.go index a194c525d..3df22eaa1 100644 --- a/common/types/accounting.go +++ b/common/types/accounting.go @@ -107,6 +107,16 @@ var ( SceneUnknow SceneType = 99 // unknow ) +type TokenUsageType string + +var ( + ExternalInference TokenUsageType = "0" + CSGHubUserDeployedInference TokenUsageType = "1" + CSGHubOtherDeployedInference TokenUsageType = "2" + CSGHubServerlessInference TokenUsageType = "3" + CSGHubOrganFellowDeployedInference TokenUsageType = "4" +) + type ChargeValueType int var (