Skip to content

Commit 942a9d9

Browse files
committed
#153: Covered the override param logic by tests & fixed anthropic tests to use chatparams
1 parent 0e511a4 commit 942a9d9

File tree

3 files changed

+177
-7
lines changed

3 files changed

+177
-7
lines changed

pkg/api/schemas/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package schemas
33
// ChatRequest defines Glide's Chat Request Schema unified across all language models
44
type ChatRequest struct {
55
Message ChatMessage `json:"message" validate:"required"`
6-
MessageHistory []ChatMessage `json:"message_history"`
6+
MessageHistory []ChatMessage `json:"message_history,omitempty"`
77
OverrideParams *map[string]ModelParamsOverride `json:"override_params,omitempty"`
88
}
99

pkg/api/schemas/chat_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package schemas
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func ToSlice(messageHistory []ChatMessage) []string {
10+
history := make([]string, 0, len(messageHistory))
11+
12+
for _, message := range messageHistory {
13+
history = append(history, message.Content)
14+
}
15+
16+
return history
17+
}
18+
19+
// TestChatRequest_DefaultParams tests param creation for a request
20+
// that doesn't have any override for a given model ID/name
21+
func TestChatRequest_DefaultParams(t *testing.T) {
22+
backstory := "You are talking to a guy who won an ACMP contest in 2015"
23+
defaultMessage := "When did I win an ACMP contest?"
24+
25+
modelID := "my-openai-model"
26+
myModelMessage := "When did he win the contest?Be concise"
27+
28+
secondModelID := "my-other-model"
29+
secondModelName := "command-r"
30+
31+
chatReq := ChatRequest{
32+
Message: ChatMessage{
33+
Role: "user",
34+
Content: defaultMessage,
35+
},
36+
MessageHistory: []ChatMessage{
37+
{
38+
Role: "system",
39+
Content: backstory,
40+
},
41+
},
42+
OverrideParams: &map[string]ModelParamsOverride{
43+
modelID: {
44+
Message: ChatMessage{
45+
Role: "user",
46+
Content: myModelMessage,
47+
},
48+
},
49+
},
50+
}
51+
52+
params := chatReq.Params(secondModelID, secondModelName)
53+
54+
require.Equal(t, []string{backstory, defaultMessage}, ToSlice(params.Messages))
55+
}
56+
57+
// TestChatRequest_ModelIDOverride tests param creation for a request
58+
// that has a param override for a modelID
59+
func TestChatRequest_ModelIDOverride(t *testing.T) {
60+
backstory := "You are talking to a guy who won an ACMP contest in 2015"
61+
defaultMessage := "When did I win an ACMP contest?"
62+
63+
modelID := "my-openai-model"
64+
modelName := "gpt-4"
65+
myModelMessage := "When did he win the contest?Be concise"
66+
67+
chatReq := ChatRequest{
68+
Message: ChatMessage{
69+
Role: "user",
70+
Content: defaultMessage,
71+
},
72+
MessageHistory: []ChatMessage{
73+
{
74+
Role: "system",
75+
Content: backstory,
76+
},
77+
},
78+
OverrideParams: &map[string]ModelParamsOverride{
79+
modelID: {
80+
Message: ChatMessage{
81+
Role: "user",
82+
Content: myModelMessage,
83+
},
84+
},
85+
},
86+
}
87+
88+
params := chatReq.Params(modelID, modelName)
89+
90+
require.Equal(t, []string{backstory, myModelMessage}, ToSlice(params.Messages))
91+
}
92+
93+
// TestChatRequest_ModelNameOverride tests param creation for a request
94+
// that has a param override for a modelName
95+
func TestChatRequest_ModelNameOverride(t *testing.T) {
96+
backstory := "You are talking to a guy who won an ACMP contest in 2015"
97+
defaultMessage := "When did I win an ACMP contest?"
98+
99+
modelID := "my-openai-model"
100+
modelName := "gpt-4"
101+
myModelMessage := "When did he win the contest?Be concise"
102+
103+
chatReq := ChatRequest{
104+
Message: ChatMessage{
105+
Role: "user",
106+
Content: defaultMessage,
107+
},
108+
MessageHistory: []ChatMessage{
109+
{
110+
Role: "system",
111+
Content: backstory,
112+
},
113+
},
114+
OverrideParams: &map[string]ModelParamsOverride{
115+
modelName: {
116+
Message: ChatMessage{
117+
Role: "user",
118+
Content: myModelMessage,
119+
},
120+
},
121+
},
122+
}
123+
124+
params := chatReq.Params(modelID, modelName)
125+
126+
require.Equal(t, []string{backstory, myModelMessage}, ToSlice(params.Messages))
127+
}
128+
129+
// TestChatRequest_ModelNameOverride tests param creation for a request
130+
// that has a param override for both modelName & modelID
131+
func TestChatRequest_ModelNameIDOverride(t *testing.T) {
132+
backstory := "You are talking to a guy who won an ACMP contest in 2015"
133+
defaultMessage := "When did I win an ACMP contest?"
134+
135+
modelID := "my-openai-model"
136+
modelName := "gpt-4"
137+
myModelIDMessage := "When did he win the contest?Be concise"
138+
myModelNameMessage := "When did he win the contest? Answer like Illya would"
139+
140+
chatReq := ChatRequest{
141+
Message: ChatMessage{
142+
Role: "user",
143+
Content: defaultMessage,
144+
},
145+
MessageHistory: []ChatMessage{
146+
{
147+
Role: "system",
148+
Content: backstory,
149+
},
150+
},
151+
OverrideParams: &map[string]ModelParamsOverride{
152+
modelName: {
153+
Message: ChatMessage{
154+
Role: "user",
155+
Content: myModelNameMessage,
156+
},
157+
},
158+
modelID: {
159+
Message: ChatMessage{
160+
Role: "user",
161+
Content: myModelIDMessage,
162+
},
163+
},
164+
},
165+
}
166+
167+
params := chatReq.Params(modelID, modelName)
168+
169+
require.Equal(t, []string{backstory, myModelIDMessage}, ToSlice(params.Messages))
170+
}

pkg/providers/anthropic/client_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ func TestAnthropicClient_ChatRequest(t *testing.T) {
5656
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
5757
require.NoError(t, err)
5858

59-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
59+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
6060
Role: "human",
6161
Content: "What's the biggest animal?",
62-
}}
62+
}}}
6363

64-
response, err := client.Chat(ctx, &request)
64+
response, err := client.Chat(ctx, &chatParams)
6565
require.NoError(t, err)
6666

6767
require.Equal(t, "msg_013Zva2CMHLNnXjNJJKqJ2EF", response.ID)
@@ -86,12 +86,12 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) {
8686
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
8787
require.NoError(t, err)
8888

89-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
89+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
9090
Role: "human",
9191
Content: "What's the biggest animal?",
92-
}}
92+
}}}
9393

94-
response, err := client.Chat(ctx, &request)
94+
response, err := client.Chat(ctx, &chatParams)
9595

9696
// Assert that an error is returned
9797
require.Error(t, err)

0 commit comments

Comments
 (0)