Skip to content

Commit 80695d3

Browse files
committed
#153: Fixed broken tests
1 parent 85e7646 commit 80695d3

File tree

8 files changed

+58
-44
lines changed

8 files changed

+58
-44
lines changed

pkg/providers/azureopenai/chat_stream_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,12 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) {
139139
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
140140
require.NoError(t, err)
141141

142-
req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
142+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
143+
Role: "user",
144+
Content: "What's the biggest animal?",
145+
}}}
143146

144-
stream, err := client.ChatStream(ctx, req)
147+
stream, err := client.ChatStream(ctx, &chatParams)
145148
require.NoError(t, err)
146149

147150
err = stream.Open()

pkg/providers/azureopenai/client_test.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) {
5555
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
5656
require.NoError(t, err)
5757

58-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
58+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
5959
Role: "user",
60-
Content: "What's the biggest animal?",
61-
}}
60+
Content: "What's the capital of the United Kingdom?",
61+
}}}
6262

63-
response, err := client.Chat(ctx, &request)
63+
response, err := client.Chat(ctx, &chatParams)
6464
require.NoError(t, err)
6565

6666
require.Equal(t, "chatcmpl-8cdqrFT2lBQlHz0EDvvq6oQcRxNcZ", response.ID)
@@ -88,12 +88,12 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) {
8888
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
8989
require.NoError(t, err)
9090

91-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
92-
Role: "user",
91+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
92+
Role: "human",
9393
Content: "What's the biggest animal?",
94-
}}
94+
}}}
9595

96-
response, err := client.Chat(ctx, &request)
96+
response, err := client.Chat(ctx, &chatParams)
9797
require.Error(t, err)
9898
require.Nil(t, response)
9999
}
@@ -115,10 +115,12 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) {
115115
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
116116
require.NoError(t, err)
117117

118-
// Create a chat request payload
119-
payload := schemas.NewChatFromStr("What's the dealio?")
118+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
119+
Role: "user",
120+
Content: "What's the dealio?",
121+
}}}
120122

121-
_, err = client.Chat(ctx, payload)
123+
_, err = client.Chat(ctx, &chatParams)
122124

123125
require.Error(t, err)
124126
require.Contains(t, err.Error(), "provider is not available")

pkg/providers/cohere/chat_stream_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,12 @@ func TestCohere_ChatStreamRequest(t *testing.T) {
7171
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
7272
require.NoError(t, err)
7373

74-
req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
74+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
75+
Role: "user",
76+
Content: "What's the capital of the United Kingdom?",
77+
}}}
7578

76-
stream, err := client.ChatStream(ctx, req)
79+
stream, err := client.ChatStream(ctx, &chatParams)
7780
require.NoError(t, err)
7881

7982
err = stream.Open()
@@ -135,8 +138,12 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
135138
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
136139
require.NoError(t, err)
137140

138-
req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
139-
stream, err := client.ChatStream(ctx, req)
141+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
142+
Role: "user",
143+
Content: "What's the capital of the United Kingdom?",
144+
}}}
145+
146+
stream, err := client.ChatStream(ctx, &chatParams)
140147
require.NoError(t, err)
141148

142149
err = stream.Open()

pkg/providers/cohere/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ func TestCohereClient_ChatRequest(t *testing.T) {
5555
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
5656
require.NoError(t, err)
5757

58-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
58+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
5959
Role: "human",
6060
Content: "What's the biggest animal?",
61-
}}
61+
}}}
6262

63-
response, err := client.Chat(ctx, &request)
63+
response, err := client.Chat(ctx, &chatParams)
6464
require.NoError(t, err)
6565

6666
require.Equal(t, "ec9eb88b-2da5-462e-8f0f-0899d243aa2e", response.ID)

pkg/providers/ollama/client_test.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,12 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) {
8585
telemetry: telemetry.NewTelemetryMock(),
8686
}
8787

88-
// Create a chat request payload
89-
payload := &ChatRequest{
90-
Messages: []ChatMessage{{Role: "human", Content: "Hello"}},
91-
}
88+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
89+
Role: "user",
90+
Content: "What's the capital of the United Kingdom?",
91+
}}}
9292

93-
// Call the chatRequest function
94-
_, err := client.doChatRequest(context.Background(), payload)
93+
_, err := client.Chat(context.Background(), &chatParams)
9594

9695
require.Error(t, err)
9796
require.Contains(t, err.Error(), "provider is not available")
@@ -114,13 +113,12 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) {
114113
telemetry: telemetry.NewTelemetryMock(),
115114
}
116115

117-
// Create a chat request payload
118-
payload := &ChatRequest{
119-
Messages: []ChatMessage{{Role: "human", Content: "Hello"}},
120-
}
116+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
117+
Role: "user",
118+
Content: "What's the capital of the United Kingdom?",
119+
}}}
121120

122-
// Call the chatRequest function
123-
response, err := client.doChatRequest(context.Background(), payload)
121+
response, err := client.Chat(context.Background(), &chatParams)
124122

125123
require.NoError(t, err)
126124
require.NotNil(t, response)

pkg/providers/openai/chat_stream_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,12 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) {
139139
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
140140
require.NoError(t, err)
141141

142-
req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
143-
stream, err := client.ChatStream(ctx, req)
142+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
143+
Role: "user",
144+
Content: "What's the capital of the United Kingdom?",
145+
}}}
146+
147+
stream, err := client.ChatStream(ctx, &chatParams)
144148
require.NoError(t, err)
145149

146150
err = stream.Open()

pkg/providers/openai/chat_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ func TestOpenAIClient_ChatRequest(t *testing.T) {
5656
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
5757
require.NoError(t, err)
5858

59-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
59+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
6060
Role: "user",
61-
Content: "What's the biggest animal?",
62-
}}
61+
Content: "What's the capital of the United Kingdom?",
62+
}}}
6363

64-
response, err := client.Chat(ctx, &request)
64+
response, err := client.Chat(ctx, &chatParams)
6565
require.NoError(t, err)
6666

6767
require.Equal(t, "chatcmpl-123", response.ID)
@@ -85,12 +85,12 @@ func TestOpenAIClient_RateLimit(t *testing.T) {
8585
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
8686
require.NoError(t, err)
8787

88-
request := schemas.ChatRequest{Message: schemas.ChatMessage{
89-
Role: "user",
88+
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
89+
Role: "human",
9090
Content: "What's the biggest animal?",
91-
}}
91+
}}}
9292

93-
_, err = client.Chat(ctx, &request)
93+
_, err = client.Chat(ctx, &chatParams)
9494

9595
require.Error(t, err)
9696
require.IsType(t, &clients.RateLimitError{}, err)

pkg/providers/testing/lang.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (c *ProviderMock) SupportChatStream() bool {
119119
return c.supportStreaming
120120
}
121121

122-
func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) {
122+
func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) {
123123
if c.chatResps == nil {
124124
return nil, clients.ErrProviderUnavailable
125125
}
@@ -136,7 +136,7 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas
136136
return response.Resp(), nil
137137
}
138138

139-
func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
139+
func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
140140
if c.chatStreams == nil || c.idx >= len(*c.chatStreams) {
141141
return nil, clients.ErrProviderUnavailable
142142
}

0 commit comments

Comments
 (0)