Skip to content

Commit f3d1660

Browse files
fixes
1 parent bc6d716 commit f3d1660

18 files changed

+215
-211
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
COHERE_API_KEY=

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@
99

1010
# rspec failure tracking
1111
.rspec_status
12+
13+
.env

Gemfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ gem "rake", "~> 13.0"
99

1010
gem "rspec", "~> 3.0"
1111
gem "standard", "~> 1.28.0"
12+
gem "dotenv", "~> 3.1.4"

Gemfile.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ GEM
99
specs:
1010
ast (2.4.2)
1111
diff-lcs (1.5.0)
12+
dotenv (3.1.4)
1213
faraday (2.7.10)
1314
faraday-net_http (>= 2.0, < 3.1)
1415
ruby2_keywords (>= 0.0.4)
@@ -74,6 +75,7 @@ PLATFORMS
7475

7576
DEPENDENCIES
7677
cohere-ruby!
78+
dotenv (~> 3.1.4)
7779
rake (~> 13.0)
7880
rspec (~> 3.0)
7981
standard (~> 1.28.0)

README.md

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
Cohere API client for Ruby.
1010

11-
Part of the [Langchain.rb](https://github.com/andreibondarev/langchainrb) stack.
11+
Part of the [Langchain.rb](https://github.com/patterns-ai-core/langchainrb) stack.
1212

13-
![Tests status](https://github.com/andreibondarev/cohere-ruby/actions/workflows/ci.yml/badge.svg)
13+
![Tests status](https://github.com/patterns-ai-core/cohere-ruby/actions/workflows/ci.yml/badge.svg)
1414
[![Gem Version](https://badge.fury.io/rb/cohere-ruby.svg)](https://badge.fury.io/rb/cohere-ruby)
1515
[![Docs](http://img.shields.io/badge/yard-docs-blue.svg)](http://rubydoc.info/gems/cohere-ruby)
16-
[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/andreibondarev/cohere-ruby/blob/main/LICENSE.txt)
16+
[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/patterns-ai-core/cohere-ruby/blob/main/LICENSE.txt)
1717
[![](https://dcbadge.vercel.app/api/server/WDARp7J2n8?compact=true&style=flat)](https://discord.gg/WDARp7J2n8)
1818

1919
## Installation
@@ -50,14 +50,18 @@ client.generate(
5050

5151
```ruby
5252
client.chat(
53-
message: "Hey! How are you?"
53+
model: "command-r-plus-08-2024",
54+
messages: [{role:"user", content: "Hey! How are you?"}]
5455
)
5556
```
5657

5758
`chat` supports a streaming option. You can pass a block to the `chat` method and it will yield a new chunk as soon as it is received.
5859

5960
```ruby
60-
client.chat(message: "Hey! How are you?", stream: true) do |chunk, overall_received_bytes|
61+
client.chat(
62+
model: "command-r-plus-08-2024",
63+
messages: [{role:"user", content: "Hey! How are you?"}]
64+
) do |chunk, overall_received_bytes|
6165
puts "Received #{overall_received_bytes} bytes: #{chunk.force_encoding(Encoding::UTF_8)}"
6266
end
6367
```
@@ -68,25 +72,25 @@ end
6872

6973
```ruby
7074
tools = [
71-
{
72-
name: "query_daily_sales_report",
73-
description: "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
74-
parameter_definitions: {
75-
day: {
76-
description: "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
77-
type: "str",
78-
required: true
79-
}
80-
}
81-
}
75+
{
76+
name: "query_daily_sales_report",
77+
description: "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
78+
parameter_definitions: {
79+
day: {
80+
description: "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
81+
type: "str",
82+
required: true
83+
}
84+
}
85+
}
8286
]
8387

8488
message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?"
8589

8690
client.chat(
8791
model: model,
88-
message: message,
89-
tools: tools,
92+
messages: [{ role:"user", content: message }],
93+
tools: tools
9094
)
9195
```
9296

bin/console

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33

44
require "bundler/setup"
55
require "cohere"
6-
7-
# You can add fixtures and/or initialization code here to make experimenting
8-
# with your gem easier. You can also use a different console, if you like.
9-
10-
# (If you use this, don't forget to add pry to your Gemfile!)
11-
# require "pry"
12-
# Pry.start
6+
require "dotenv/load"
137

148
client = Cohere::Client.new(
159
api_key: ENV['COHERE_API_KEY']

lib/cohere/client.rb

Lines changed: 79 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,62 +6,56 @@ module Cohere
66
class Client
77
attr_reader :api_key, :connection
88

9-
ENDPOINT_URL = "https://api.cohere.com/v2"
10-
119
def initialize(api_key:, timeout: nil)
1210
@api_key = api_key
1311
@timeout = timeout
1412
end
1513

14+
# Generates a text response to a user message and streams it down, token by token
1615
def chat(
17-
message: nil,
18-
model: nil,
16+
model:,
17+
messages:,
1918
stream: false,
20-
preamble: nil,
21-
preamble_override: nil,
22-
chat_history: [],
23-
conversation_id: nil,
24-
prompt_truncation: nil,
25-
connectors: [],
26-
search_queries_only: false,
19+
tools: [],
2720
documents: [],
28-
citation_quality: nil,
29-
temperature: nil,
21+
citation_options: nil,
22+
response_format: nil,
23+
safety_mode: nil,
3024
max_tokens: nil,
31-
k: nil,
32-
p: nil,
25+
stop_sequences: nil,
26+
temperature: nil,
3327
seed: nil,
3428
frequency_penalty: nil,
3529
presence_penalty: nil,
36-
tools: [],
30+
k: nil,
31+
p: nil,
32+
logprops: nil,
3733
&block
3834
)
39-
response = connection.post("chat") do |req|
35+
response = v2_connection.post("chat") do |req|
4036
req.body = {}
4137

42-
req.body[:message] = message if message
43-
req.body[:model] = model if model
44-
if stream || block
45-
req.body[:stream] = true
46-
req.options.on_data = block if block
47-
end
48-
req.body[:preamble] = preamble if preamble
49-
req.body[:preamble_override] = preamble_override if preamble_override
50-
req.body[:chat_history] = chat_history if chat_history
51-
req.body[:conversation_id] = conversation_id if conversation_id
52-
req.body[:prompt_truncation] = prompt_truncation if prompt_truncation
53-
req.body[:connectors] = connectors if connectors
54-
req.body[:search_queries_only] = search_queries_only if search_queries_only
55-
req.body[:documents] = documents if documents
56-
req.body[:citation_quality] = citation_quality if citation_quality
57-
req.body[:temperature] = temperature if temperature
38+
req.body[:model] = model
39+
req.body[:messages] = messages if messages
40+
req.body[:tools] = tools if tools.any?
41+
req.body[:documents] = documents if documents.any?
42+
req.body[:citation_options] = citation_options if citation_options
43+
req.body[:response_format] = response_format if response_format
44+
req.body[:safety_mode] = safety_mode if safety_mode
5845
req.body[:max_tokens] = max_tokens if max_tokens
59-
req.body[:k] = k if k
60-
req.body[:p] = p if p
46+
req.body[:stop_sequences] = stop_sequences if stop_sequences
47+
req.body[:temperature] = temperature if temperature
6148
req.body[:seed] = seed if seed
6249
req.body[:frequency_penalty] = frequency_penalty if frequency_penalty
6350
req.body[:presence_penalty] = presence_penalty if presence_penalty
64-
req.body[:tools] = tools if tools
51+
req.body[:k] = k if k
52+
req.body[:p] = p if p
53+
req.body[:logprops] = logprops if logprops
54+
55+
if stream || block
56+
req.body[:stream] = true
57+
req.options.on_data = block if block
58+
end
6559
end
6660
response.body
6761
end
@@ -104,36 +98,44 @@ def generate(
10498
response.body
10599
end
106100

101+
# This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
107102
def embed(
108-
texts:,
109-
model: nil,
110-
input_type: nil,
103+
model:,
104+
input_type:,
105+
embedding_types:,
106+
texts: nil,
107+
images: nil,
111108
truncate: nil
112109
)
113-
response = connection.post("embed") do |req|
114-
req.body = {texts: texts}
115-
req.body[:model] = model if model
116-
req.body[:input_type] = input_type if input_type
110+
response = v2_connection.post("embed") do |req|
111+
req.body = {
112+
model: model,
113+
input_type: input_type,
114+
embedding_types: embedding_types
115+
}
116+
req.body[:texts] = texts if texts
117+
req.body[:images] = images if images
117118
req.body[:truncate] = truncate if truncate
118119
end
119120
response.body
120121
end
121122

123+
# This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
122124
def rerank(
125+
model:,
123126
query:,
124127
documents:,
125-
model: nil,
126128
top_n: nil,
127129
rank_fields: nil,
128130
return_documents: nil,
129131
max_chunks_per_doc: nil
130132
)
131-
response = connection.post("rerank") do |req|
133+
response = v2_connection.post("rerank") do |req|
132134
req.body = {
135+
model: model,
133136
query: query,
134137
documents: documents
135138
}
136-
req.body[:model] = model if model
137139
req.body[:top_n] = top_n if top_n
138140
req.body[:rank_fields] = rank_fields if rank_fields
139141
req.body[:return_documents] = return_documents if return_documents
@@ -142,41 +144,44 @@ def rerank(
142144
response.body
143145
end
144146

147+
# This endpoint makes a prediction about which label fits the specified text inputs best.
145148
def classify(
149+
model:,
146150
inputs:,
147-
examples:,
148-
model: nil,
149-
present: nil,
151+
examples: nil,
152+
preset: nil,
150153
truncate: nil
151154
)
152-
response = connection.post("classify") do |req|
155+
response = v1_connection.post("classify") do |req|
153156
req.body = {
154-
inputs: inputs,
155-
examples: examples
157+
model: model,
158+
inputs: inputs
156159
}
157-
req.body[:model] = model if model
158-
req.body[:present] = present if present
160+
req.body[:examples] = examples if examples
161+
req.body[:preset] = preset if preset
159162
req.body[:truncate] = truncate if truncate
160163
end
161164
response.body
162165
end
163166

164-
def tokenize(text:, model: nil)
165-
response = connection.post("tokenize") do |req|
166-
req.body = model.nil? ? {text: text} : {text: text, model: model}
167+
# This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE).
168+
def tokenize(text:, model:)
169+
response = v1_connection.post("tokenize") do |req|
170+
req.body = {text: text, model: model}
167171
end
168172
response.body
169173
end
170174

171-
def detokenize(tokens:, model: nil)
172-
response = connection.post("detokenize") do |req|
173-
req.body = model.nil? ? {tokens: tokens} : {tokens: tokens, model: model}
175+
# This endpoint takes tokens using byte-pair encoding and returns their text representation.
176+
def detokenize(tokens:, model:)
177+
response = v1_connection.post("detokenize") do |req|
178+
req.body = {tokens: tokens, model: model}
174179
end
175180
response.body
176181
end
177182

178183
def detect_language(texts:)
179-
response = connection.post("detect-language") do |req|
184+
response = v1_connection.post("detect-language") do |req|
180185
req.body = {texts: texts}
181186
end
182187
response.body
@@ -191,7 +196,7 @@ def summarize(
191196
temperature: nil,
192197
additional_command: nil
193198
)
194-
response = connection.post("summarize") do |req|
199+
response = v1_connection.post("summarize") do |req|
195200
req.body = {text: text}
196201
req.body[:length] = length if length
197202
req.body[:format] = format if format
@@ -205,17 +210,22 @@ def summarize(
205210

206211
private
207212

208-
# standard:disable Lint/DuplicateMethods
209-
def connection
210-
@connection ||= Faraday.new(url: ENDPOINT_URL, request: {timeout: @timeout}) do |faraday|
211-
if api_key
212-
faraday.request :authorization, :Bearer, api_key
213-
end
213+
def v1_connection
214+
@connection ||= Faraday.new(url: "https://api.cohere.ai/v1", request: {timeout: @timeout}) do |faraday|
215+
faraday.request :authorization, :Bearer, api_key
216+
faraday.request :json
217+
faraday.response :json, content_type: /\bjson$/
218+
faraday.adapter Faraday.default_adapter
219+
end
220+
end
221+
222+
def v2_connection
223+
@connection ||= Faraday.new(url: "https://api.cohere.com/v2", request: {timeout: @timeout}) do |faraday|
224+
faraday.request :authorization, :Bearer, api_key
214225
faraday.request :json
215226
faraday.response :json, content_type: /\bjson$/
216227
faraday.adapter Faraday.default_adapter
217228
end
218229
end
219-
# standard:enable Lint/DuplicateMethods
220230
end
221231
end

0 commit comments

Comments
 (0)