Skip to content

Commit 122f2e1

Browse files
committed
#153: Adjusted the language model interface to use the params
1 parent 9a0e46e commit 122f2e1

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pkg/providers/lang.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ type LangProvider interface {
2222

2323
SupportChatStream() bool
2424

25-
Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
26-
ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error)
25+
Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)
26+
ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error)
2727
}
2828

2929
type LangModel interface {
3030
Model
3131
Provider() string
32-
Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
33-
ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error)
32+
Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)
33+
ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error)
3434
}
3535

3636
// LanguageModel wraps provider client and expend it with health & latency tracking
@@ -87,10 +87,10 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {
8787
return m.chatStreamLatency
8888
}
8989

90-
func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
90+
func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
9191
startedAt := time.Now()
9292

93-
resp, err := m.client.Chat(ctx, request)
93+
resp, err := m.client.Chat(ctx, params)
9494
if err != nil {
9595
m.healthTracker.TrackErr(err)
9696

@@ -106,8 +106,8 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
106106
return resp, err
107107
}
108108

109-
func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) {
110-
stream, err := m.client.ChatStream(ctx, req)
109+
func (m *LanguageModel) ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) {
110+
stream, err := m.client.ChatStream(ctx, params)
111111

112112
if err != nil {
113113
m.healthTracker.TrackErr(err)

0 commit comments

Comments
 (0)