Skip to content

Commit 6fc75b5

Browse files
add chat() method
1 parent 7144bdb commit 6fc75b5

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

Gemfile.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ GEM
6767
unicode-display_width (2.4.2)
6868

6969
PLATFORMS
70+
arm64-darwin-23
7071
x86_64-darwin-19
7172
x86_64-darwin-21
7273
x86_64-linux

lib/cohere/client.rb

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,47 @@ def initialize(api_key)
1212
@api_key = api_key
1313
end
1414

15+
def chat(
16+
message:,
17+
model: nil,
18+
stream: false,
19+
preamble_override: nil,
20+
chat_history: [],
21+
conversation_id: nil,
22+
prompt_truncation: nil,
23+
connectors: [],
24+
search_queries_only: false,
25+
documents: [],
26+
citation_quality: nil,
27+
temperature: nil,
28+
max_tokens: nil,
29+
k: nil,
30+
p: nil,
31+
frequency_penalty: nil,
32+
presence_penalty: nil
33+
)
34+
response = connection.post("chat") do |req|
35+
req.body = {message: message}
36+
req.body[:model] = model if model
37+
req.body[:stream] = stream if stream
38+
req.body[:preamble_override] = preamble_override if preamble_override
39+
req.body[:chat_history] = chat_history if chat_history
40+
req.body[:conversation_id] = conversation_id if conversation_id
41+
req.body[:prompt_truncation] = prompt_truncation if prompt_truncation
42+
req.body[:connectors] = connectors if connectors
43+
req.body[:search_queries_only] = search_queries_only if search_queries_only
44+
req.body[:documents] = documents if documents
45+
req.body[:citation_quality] = citation_quality if citation_quality
46+
req.body[:temperature] = temperature if temperature
47+
req.body[:max_tokens] = max_tokens if max_tokens
48+
req.body[:k] = k if k
49+
req.body[:p] = p if p
50+
req.body[:frequency_penalty] = frequency_penalty if frequency_penalty
51+
req.body[:presence_penalty] = presence_penalty if presence_penalty
52+
end
53+
response.body
54+
end
55+
1556
# This endpoint generates realistic text conditioned on a given input.
1657
def generate(
1758
prompt:,

0 commit comments

Comments
 (0)