Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 97 additions & 11 deletions aigateway/component/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type OpenAIComponent interface {

type openaiComponentImpl struct {
userStore database.UserStore
organStore database.OrgStore
deployStore database.DeployTaskStore
eventPub *event.EventPublisher
extllmStore database.LLMConfigStore
Expand All @@ -51,6 +52,7 @@ func NewOpenAIComponentFromConfig(config *config.Config) (OpenAIComponent, error
}
return &openaiComponentImpl{
userStore: database.NewUserStore(),
organStore: database.NewOrgStore(),
deployStore: database.NewDeployTaskStore(),
eventPub: &event.DefaultEventPublisher,
extllmStore: database.NewLLMConfigStore(config),
Expand Down Expand Up @@ -108,6 +110,7 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userName string
},
InternalModelInfo: types.InternalModelInfo{
CSGHubModelID: deploy.Repository.Path,
OwnerUUID: deploy.User.UUID,
ClusterID: deploy.ClusterID,
SvcName: deploy.SvcName,
SvcType: deploy.Type,
Expand Down Expand Up @@ -284,23 +287,67 @@ func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, mo
var tokenUsageExtra = struct {
PromptTokenNum string `json:"prompt_token_num"`
CompletionTokenNum string `json:"completion_token_num"`
// 0: external, 1: owner is user, 2: other user is inference, 3: serverless
OwnerType commontypes.TokenUsageType
}{
PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens),
CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens),
}
if model.CSGHubModelID != "" && model.Provider != "" {
slog.WarnContext(c, "bad model info, both csghub model id and external model provider is set",
slog.Any("model info", model))
}
if model.CSGHubModelID == "" && model.Provider == "" {
slog.WarnContext(c, "bad model info, both csghub model id and external model provider is not set",
slog.Any("model info", model))
}
if model.CSGHubModelID != "" {
switch model.SvcType {
case commontypes.ServerlessType:
tokenUsageExtra.OwnerType = commontypes.CSGHubServerlessInference
case commontypes.InferenceType:
if model.OwnerUUID == userUUID {
tokenUsageExtra.OwnerType = commontypes.CSGHubUserDeployedInference
} else {
belong, err := m.checkOrganization(c, userUUID, model.OwnerUUID)
if err != nil {
return fmt.Errorf("failed to check organization,error:%w", err)
}
if belong {
tokenUsageExtra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference
} else {
tokenUsageExtra.OwnerType = commontypes.CSGHubOtherDeployedInference
}
}
default:
slog.WarnContext(c, "bad model info, csghub model missing service type",
slog.Any("model info", model))
}
}
if model.Provider != "" {
tokenUsageExtra.OwnerType = commontypes.ExternalInference
}

extraData, _ := json.Marshal(tokenUsageExtra)
event := commontypes.MeteringEvent{
Uuid: uuid.New(),
UserUUID: userUUID,
Value: usage.TotalTokens,
ValueType: commontypes.TokenNumberType, // count by token
Scene: getSceneFromSvcType(model.SvcType),
OpUID: "",
ResourceID: model.CSGHubModelID,
ResourceName: model.CSGHubModelID,
CustomerID: model.SvcName,
CreatedAt: time.Now(),
Extra: string(extraData),
Uuid: uuid.New(),
UserUUID: userUUID,
Value: usage.TotalTokens,
ValueType: commontypes.TokenNumberType, // count by token
Scene: int(commontypes.SceneModelServerless),
OpUID: "aigateway",
CreatedAt: time.Now(),
Extra: string(extraData),
}
if model.CSGHubModelID != "" {
event.ResourceID = model.CSGHubModelID
event.ResourceName = model.CSGHubModelID
event.CustomerID = model.SvcName
}
if model.Provider != "" {
event.ResourceID = model.ID
event.ResourceName = model.ID
event.CustomerID = model.Provider
}
eventData, _ := json.Marshal(event)
err = m.eventPub.PublishMeteringEvent(eventData)
Expand All @@ -312,3 +359,42 @@ func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, mo
slog.Info("public token usage event success", "event", event)
return nil
}

func (m *openaiComponentImpl) checkOrganization(c context.Context, userUUID string, ownerUUID string) (bool, error) {
user, err := m.userStore.FindByUUID(c, userUUID)
if err != nil {
slog.ErrorContext(c, "Failed to find user in db")
return false, err
}
owner, err := m.userStore.FindByUUID(c, ownerUUID)
if err != nil {
slog.ErrorContext(c, "Failed to find owner in db")
return false, err
}
userOrgs, err := m.organStore.GetUserBelongOrgs(c, user.ID)
if err != nil {
slog.ErrorContext(c, "Failed to find user organizations")
return false, err
}
if len(userOrgs) == 0 {
return false, nil
}
ownerOrgs, err := m.organStore.GetUserBelongOrgs(c, owner.ID)
if err != nil {
slog.ErrorContext(c, "Failed to find owner organizations")
return false, err
}
if len(ownerOrgs) == 0 {
return false, nil
}
userOrgansMap := make(map[int64]struct{}, len(userOrgs))
for _, org := range userOrgs {
userOrgansMap[org.ID] = struct{}{}
}
for _, org := range ownerOrgs {
if _, ok := userOrgansMap[org.ID]; ok {
return true, nil
}
}
return false, nil
}
Loading
Loading