Commit 00b699e

Add LiteLLM Wrapper support (AST-113982) (#38)
1 parent 73a8769 · commit 00b699e

File tree: 7 files changed · +217 −149 lines

example/cxoneai.go

Lines changed: 7 additions & 2 deletions
@@ -50,9 +50,14 @@ func getOAuthAccessToken() (string, error) {
 	}
 
 	data := url.Values{}
-	data.Set("grant_type", "client_credentials")
+	data.Set("grant_type", "refresh_token")
 	data.Set("client_id", clientID)
-	data.Set("client_secret", clientSecret)
+	data.Set("refresh_token", clientSecret)
+
+	//Use this if you have client credentials
+	//data.Set("grant_type", "client_credentials")
+	//data.Set("client_id", clientID)
+	//data.Set("client_secret", clientSecret)
 
 	req, err := http.NewRequest("POST", openIDURL, strings.NewReader(data.Encode()))
 	if err != nil {
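
For reference, these are the two standard OAuth2 grant shapes the example now toggles between. A minimal sketch, where buildTokenRequest is a hypothetical helper (not part of this repo) and the form fields are the standard OAuth2 token-request parameters:

package authexample

import "net/url"

// buildTokenRequest builds the form body for an OAuth2 token request.
// A non-empty refreshToken selects the refresh-token grant, which the
// example above now uses; otherwise it falls back to client credentials.
// Hypothetical helper, for illustration only.
func buildTokenRequest(clientID, clientSecret, refreshToken string) url.Values {
	data := url.Values{}
	if refreshToken != "" {
		data.Set("grant_type", "refresh_token")
		data.Set("client_id", clientID)
		data.Set("refresh_token", refreshToken)
	} else {
		data.Set("grant_type", "client_credentials")
		data.Set("client_id", clientID)
		data.Set("client_secret", clientSecret)
	}
	return data
}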

example/main.go

Lines changed: 39 additions & 19 deletions
@@ -3,14 +3,15 @@ package main
 import (
 	"flag"
 	"fmt"
-	"github.com/Checkmarx/gen-ai-wrapper/pkg/connector"
+	"os"
+	"strings"
+
+	"github.com/Checkmarx/gen-ai-wrapper/internal"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/role"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper"
 	"github.com/google/uuid"
-	"os"
-	"strings"
 )
 
 const usage = `
@@ -23,9 +24,8 @@ Options
 	-s, --system <system-prompt>	system (or developer) prompt string
 	-u, --user <user-prompt>	user prompt string
 	-id <conversation-id>		chat conversation ID
-	-ai <ai-server>			AI server to use. Options: {OpenAI (default), CxOne}
-	-m, --model <model>		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, ...}
-	-f, --full-response		return full response from AI
+	-ai <ai-server>			AI server to use. Options: {OpenAI (default), CxOne, LiteLLM}
+	-m, --model <model>		model to use. Options: {gpt-4o (default), gpt-4, o1, o1-mini, claude-3-5-sonnet-20241022, ...}
 	-h, --help			show help
 `
@@ -103,26 +103,46 @@ func CallAIandPrintResponse(aiServer, model, systemPrompt, userPrompt string, ch
 		return err
 	}
 
-	statefulWrapper, err := wrapper.NewStatefulWrapperNew(
-		connector.NewFileSystemConnector(""), aiEndpoint, aiKey, model, 4, 0)
+	var litellmWrapper wrapper.LitellmWrapper
+
+	// Use litellm wrapper for litellm server
+	if strings.EqualFold(aiServer, "LiteLLM") {
+		litellmWrapper, err = wrapper.NewLitellmWrapper(aiEndpoint, aiKey, model)
+	} else {
+		// For other servers, we'll need to implement or use existing wrappers
+		return fmt.Errorf("unsupported AI server: %s", aiServer)
+	}
+
 	if err != nil {
 		return fmt.Errorf("error creating '%s' AI client: %v", aiServer, err)
 	}
 
 	newMessages := GetMessages(model, systemPrompt, userPrompt)
 
+	// Create proper metadata for the request
+	metaData := &message.MetaData{
+		RequestID: chatId.String(),
+		TenantID:  "default-tenant",
+		UserAgent: "gen-ai-wrapper-example",
+		Feature:   "chat-completion",
+	}
+
+	// Create the request
+	request := &internal.ChatCompletionRequest{
+		Model:    model,
+		Messages: newMessages,
+	}
+
+	// Make the call
+	response, err := litellmWrapper.Call(aiKey, metaData, request)
+	if err != nil {
+		return fmt.Errorf("error calling litellm: %v", err)
+	}
+
 	if fullResponse {
-		response, err := statefulWrapper.SecureCallReturningFullResponse("", nil, chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
 		fmt.Printf("%+v\n", response)
 	} else {
-		response, err := statefulWrapper.Call(chatId, newMessages)
-		if err != nil {
-			return fmt.Errorf("error calling GPT: %v", err)
-		}
-		fmt.Println(getMessageContents(response))
+		fmt.Println(response.Choices[0].Message.Content)
 	}
 	return nil
 }
@@ -156,7 +176,7 @@ func getAIAccessKey(aiServer, model string) (string, error) {
 		}
 		return accessKey, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		accessKey, err := GetCxOneAIAccessKey()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI API key: %v", err)
@@ -174,7 +194,7 @@ func getAIEndpoint(aiServer string) (string, error) {
 		}
 		return aiEndpoint, nil
 	}
-	if strings.EqualFold(aiServer, "CxOne") {
+	if strings.EqualFold(aiServer, "CxOne") || strings.EqualFold(aiServer, "LiteLLM") {
 		aiEndpoint, err := GetCxOneAIEndpoint()
 		if err != nil {
 			return "", fmt.Errorf("error getting CxOne AI endpoint: %v", err)

internal/genaiProxyInternal.go

Lines changed: 0 additions & 119 deletions
This file was deleted.

internal/gpt.go

Lines changed: 6 additions & 9 deletions
@@ -3,7 +3,6 @@ package internal
 import (
 	"errors"
 	"fmt"
-	"net/url"
 
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
 	"github.com/Checkmarx/gen-ai-wrapper/pkg/role"
@@ -49,14 +48,12 @@ type Wrapper interface {
 }
 
 func NewWrapperFactory(endPoint, apiKey string, dropLen int) (Wrapper, error) {
-	endPointURL, err := url.Parse(endPoint)
-	if err != nil {
-		return nil, err
-	}
-	if endPointURL.Scheme == "http" || endPointURL.Scheme == "https" {
-		return NewWrapperImpl(endPoint, apiKey, dropLen), nil
-	}
-	return NewWrapperInternalImpl(endPoint, dropLen)
+	return NewWrapperImpl(endPoint, apiKey, dropLen), nil
+}
+
+// NewLitellmWrapperFactory creates a new litellm wrapper factory
+func NewLitellmWrapperFactory(endPoint, apiKey string) (Wrapper, error) {
+	return NewLitellmWrapper(endPoint, apiKey), nil
 }
 
 func fromResponse(statusCode int, e *ErrorResponse) error {
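
With the URL-scheme branch gone, both factories are plain constructors. A sketch of a call site choosing between them; newWrapperFor is a hypothetical helper and the dropLen value is a placeholder:

package internal

import "strings"

// newWrapperFor picks a factory by server name (hypothetical helper).
func newWrapperFor(server, endPoint, apiKey string) (Wrapper, error) {
	if strings.EqualFold(server, "LiteLLM") {
		return NewLitellmWrapperFactory(endPoint, apiKey)
	}
	// dropLen of 4 is a placeholder, mirroring the example's old wiring.
	return NewWrapperFactory(endPoint, apiKey, 4)
}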

internal/litellm_wrapper.go

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+package internal
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
+)
+
+// LitellmWrapper implements the Wrapper interface for litellm AI proxy service
+type LitellmWrapper struct {
+	endPoint string
+	apiKey   string
+}
+
+// NewLitellmWrapper creates a new litellm wrapper instance
+func NewLitellmWrapper(endPoint, apiKey string) Wrapper {
+	return &LitellmWrapper{
+		endPoint: endPoint,
+		apiKey:   apiKey,
+	}
+}
+
+// SetupCall sets up the wrapper with initial messages (no-op for litellm)
+func (w *LitellmWrapper) SetupCall(messages []message.Message) {
+	// No setup needed for litellm
+}
+
+// Call makes a request to the litellm AI proxy service
+func (w *LitellmWrapper) Call(cxAuth string, metaData *message.MetaData, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	// Prepare the request
+	req, err := w.prepareRequest(cxAuth, metaData, request)
+	if err != nil {
+		return nil, err
+	}
+
+	// Make the HTTP request
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Handle the response
+	return w.handleResponse(resp)
+}
+
+// prepareRequest creates the HTTP request
+func (w *LitellmWrapper) prepareRequest(cxAuth string, metaData *message.MetaData, requestBody *ChatCompletionRequest) (*http.Request, error) {
+	jsonData, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(http.MethodPost, w.endPoint, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, err
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cxAuth))
+
+	// Set required headers for litellm service
+	req.Header.Set("X-Request-ID", metaData.RequestID)
+	req.Header.Set("X-Tenant-ID", metaData.TenantID)
+	req.Header.Set("User-Agent", metaData.UserAgent)
+	req.Header.Set("X-Feature", metaData.Feature)
+
+	return req, nil
+}
+
+// handleResponse processes the HTTP response
+func (w *LitellmWrapper) handleResponse(resp *http.Response) (*ChatCompletionResponse, error) {
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	// Handle successful response
+	if resp.StatusCode == http.StatusOK {
+		var responseBody = new(ChatCompletionResponse)
+		err = json.Unmarshal(bodyBytes, responseBody)
+		if err != nil {
+			return nil, err
+		}
+		return responseBody, nil
+	}
+
+	// Handle error responses
+	var errorResponse = new(ErrorResponse)
+	err = json.Unmarshal(bodyBytes, errorResponse)
+	if err != nil {
+		// If we can't parse the error response, return a generic error
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	// Return the parsed error
+	return nil, fromResponse(resp.StatusCode, errorResponse)
+}
+
+// Close closes the wrapper (no-op for HTTP client)
+func (w *LitellmWrapper) Close() error {
+	return nil
+}
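
A sketch of how this wrapper could be exercised against a stub server, e.g. from a test in the same package. The stub's JSON body is an assumption modeled on the OpenAI chat-completion shape that ChatCompletionResponse appears to mirror (response.Choices[0].Message.Content in the example above):

package internal

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/Checkmarx/gen-ai-wrapper/pkg/message"
)

func TestLitellmWrapperCall(t *testing.T) {
	// Stub server that checks the forwarded metadata headers and
	// returns a canned chat completion (assumed OpenAI-style JSON).
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Tenant-ID") != "tenant-1" || r.Header.Get("X-Request-ID") == "" {
			t.Error("expected metadata headers to be set")
		}
		fmt.Fprint(w, `{"choices":[{"message":{"role":"assistant","content":"hi"}}]}`)
	}))
	defer srv.Close()

	lw := NewLitellmWrapper(srv.URL, "test-key").(*LitellmWrapper)
	resp, err := lw.Call("test-key", &message.MetaData{
		RequestID: "req-1", TenantID: "tenant-1", UserAgent: "ua", Feature: "chat",
	}, &ChatCompletionRequest{Model: "gpt-4o"})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Choices[0].Message.Content != "hi" {
		t.Errorf("unexpected content: %q", resp.Choices[0].Message.Content)
	}
}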

pkg/models/models.go

Lines changed: 3 additions & 0 deletions
@@ -15,5 +15,8 @@ const (
 	GPT3TextDavinci001 = "text-davinci-001"
 	GPT3TextDavinci002 = "text-davinci-002"
 	GPT3TextDavinci003 = "text-davinci-003"
+	ClaudeSonnet37     = "claude-sonnet-3-7"
+	ClaudeSonnet4      = "claude-sonnet-4"
+	ClaudeSonnet45     = "claude-sonnet-4-5"
 	DefaultModel       = GPT4o
 )
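
The new constants are plain model-ID strings, so anything that takes a model name accepts them. A trivial sketch:

package main

import (
	"fmt"

	"github.com/Checkmarx/gen-ai-wrapper/pkg/models"
)

func main() {
	// The new Claude constants sit alongside the existing GPT ones;
	// DefaultModel is still GPT4o.
	for _, m := range []string{models.ClaudeSonnet37, models.ClaudeSonnet4, models.ClaudeSonnet45} {
		fmt.Println(m)
	}
	fmt.Println("default:", models.DefaultModel)
}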
