Skip to content

Commit b1c4fdb

Browse files
Add rerank method and bump version
1 parent 320d9c8 commit b1c4fdb

File tree

7 files changed

+136
-34
lines changed

7 files changed

+136
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
## [Unreleased]
22

3+
## [0.9.11] - 2024-08-01
4+
- New `rerank()` method
5+
36
## [0.9.10] - 2024-05-10
47
- /chat endpoint does not require `message:` parameter anymore
58

Gemfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
PATH
22
remote: .
33
specs:
4-
cohere-ruby (0.9.10)
4+
cohere-ruby (0.9.11)
55
faraday (>= 2.0.1, < 3.0)
66

77
GEM

README.md

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Cohere
22

33
<p>
4-
<img alt='Weaviate logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
4+
<img alt='Cohere logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
55
+&nbsp;&nbsp;
66
<img alt='Ruby logo' src='https://user-images.githubusercontent.com/541665/230231593-43861278-4550-421d-a543-fd3553aac4f6.png' height='40' />
77
</p>
@@ -42,15 +42,15 @@ client = Cohere::Client.new(
4242

4343
```ruby
4444
client.generate(
45-
prompt: "Once upon a time in a magical land called"
45+
prompt: "Once upon a time in a magical land called"
4646
)
4747
```
4848

4949
### Chat
5050

5151
```ruby
5252
client.chat(
53-
message: "Hey! How are you?"
53+
message: "Hey! How are you?"
5454
)
5555
```
5656

@@ -90,30 +90,45 @@ client.chat(
9090
)
9191
```
9292

93-
94-
9593
### Embed
9694

9795
```ruby
9896
client.embed(
99-
texts: ["hello!"]
97+
texts: ["hello!"]
98+
)
99+
```
100+
101+
### Rerank
102+
103+
```ruby
104+
docs = [
105+
"Carson City is the capital city of the American state of Nevada.",
106+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
107+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
108+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
109+
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
110+
]
111+
112+
client.rerank(
113+
query: "What is the capital of the United States?", documents: docs
100114
)
101115
```
102116

117+
103118
### Classify
104119

105120
```ruby
106121
examples = [
107-
{ text: "Dermatologists don't like her!", label: "Spam" },
108-
{ text: "Hello, open to this?", label: "Spam" },
109-
{ text: "I need help please wire me $1000 right now", label: "Spam" },
110-
{ text: "Nice to know you ;)", label: "Spam" },
111-
{ text: "Please help me?", label: "Spam" },
112-
{ text: "Your parcel will be delivered today", label: "Not spam" },
113-
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
114-
{ text: "Weekly sync notes", label: "Not spam" },
115-
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
116-
{ text: "Pre-read for tomorrow", label: "Not spam" }
122+
{ text: "Dermatologists don't like her!", label: "Spam" },
123+
{ text: "Hello, open to this?", label: "Spam" },
124+
{ text: "I need help please wire me $1000 right now", label: "Spam" },
125+
{ text: "Nice to know you ;)", label: "Spam" },
126+
{ text: "Please help me?", label: "Spam" },
127+
{ text: "Your parcel will be delivered today", label: "Not spam" },
128+
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
129+
{ text: "Weekly sync notes", label: "Not spam" },
130+
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
131+
{ text: "Pre-read for tomorrow", label: "Not spam" }
117132
]
118133

119134
inputs = [
@@ -122,40 +137,40 @@ inputs = [
122137
]
123138

124139
client.classify(
125-
examples: examples,
126-
inputs: inputs
140+
examples: examples,
141+
inputs: inputs
127142
)
128143
```
129144

130145
### Tokenize
131146

132147
```ruby
133148
client.tokenize(
134-
text: "hello world!"
149+
text: "hello world!"
135150
)
136151
```
137152

138153
### Detokenize
139154

140155
```ruby
141156
client.detokenize(
142-
tokens: [33555, 1114 , 34]
157+
tokens: [33555, 1114 , 34]
143158
)
144159
```
145160

146161
### Detect language
147162

148163
```ruby
149164
client.detect_language(
150-
texts: ["Здравствуй, Мир"]
165+
texts: ["Здравствуй, Мир"]
151166
)
152167
```
153168

154169
### Summarize
155170

156171
```ruby
157172
client.summarize(
158-
text: "..."
173+
text: "..."
159174
)
160175
```
161176

lib/cohere/client.rb

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,30 @@ def embed(
119119
response.body
120120
end
121121

122+
def rerank(
123+
query:,
124+
documents:,
125+
model: nil,
126+
top_n: nil,
127+
rank_fields: nil,
128+
return_documents: nil,
129+
max_chunks_per_doc: nil
130+
)
131+
response = connection.post("rerank") do |req|
132+
req.body = {
133+
query: query,
134+
documents: documents
135+
}
136+
req.body[:model] = model if model
137+
req.body[:top_n] = top_n if top_n
138+
req.body[:rank_fields] = rank_fields if rank_fields
139+
req.body[:return_documents] = return_documents if return_documents
140+
req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc
141+
end
142+
response.body
143+
144+
end
145+
122146
def classify(
123147
inputs:,
124148
examples:,

lib/cohere/version.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# frozen_string_literal: true
22

33
module Cohere
4-
VERSION = "0.9.10"
4+
VERSION = "0.9.11"
55
end

spec/cohere/client_spec.rb

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
require "spec_helper"
44

55
RSpec.describe Cohere::Client do
6-
let(:instance) { described_class.new(api_key: "123") }
6+
subject { described_class.new(api_key: "123") }
77

88
describe "#generate" do
99
let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) }
@@ -16,7 +16,7 @@
1616
end
1717

1818
it "returns a response" do
19-
expect(instance.generate(
19+
expect(subject.generate(
2020
prompt: "Once upon a time in a magical land called"
2121
).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.")
2222
end
@@ -33,12 +33,39 @@
3333
end
3434

3535
it "returns a response" do
36-
expect(instance.embed(
36+
expect(subject.embed(
3737
texts: ["hello!"]
3838
).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]])
3939
end
4040
end
4141

42+
describe "#rerank" do
43+
let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) }
44+
let(:response) { OpenStruct.new(body: embed_result) }
45+
let(:docs) {[
46+
"Carson City is the capital city of the American state of Nevada.",
47+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
48+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
49+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
50+
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
51+
]}
52+
53+
before do
54+
allow_any_instance_of(Faraday::Connection).to receive(:post)
55+
.with("rerank")
56+
.and_return(response)
57+
end
58+
59+
it "returns a response" do
60+
expect(
61+
subject
62+
.rerank(query: "What is the capital of the United States?", documents: docs)
63+
.dig("results")
64+
.map {|h| h["index"]}
65+
).to eq([3, 4, 2, 0, 1])
66+
end
67+
end
68+
4269
describe "#classify" do
4370
let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) }
4471
let(:response) { OpenStruct.new(body: classify_result) }
@@ -64,7 +91,7 @@
6491
end
6592

6693
it "returns a response" do
67-
res = instance.classify(
94+
res = subject.classify(
6895
inputs: inputs,
6996
examples: examples
7097
).dig("classifications")
@@ -85,7 +112,7 @@
85112
end
86113

87114
it "returns a response" do
88-
expect(instance.tokenize(
115+
expect(subject.tokenize(
89116
text: "Hello, world!"
90117
).dig("tokens")).to eq([33555, 1114, 34])
91118
end
@@ -102,7 +129,7 @@
102129
end
103130

104131
it "returns a response" do
105-
expect(instance.tokenize(
132+
expect(subject.tokenize(
106133
text: "Hello, world!",
107134
model: "base"
108135
).dig("tokens")).to eq([33555, 1114, 34])
@@ -120,7 +147,7 @@
120147
end
121148

122149
it "returns a response" do
123-
expect(instance.detokenize(
150+
expect(subject.detokenize(
124151
tokens: [33555, 1114, 34]
125152
).dig("text")).to eq("hello world!")
126153
end
@@ -137,7 +164,7 @@
137164
end
138165

139166
it "returns a response" do
140-
expect(instance.detokenize(
167+
expect(subject.detokenize(
141168
tokens: [33555, 1114, 34],
142169
model: "base"
143170
).dig("text")).to eq("hello world!")
@@ -155,7 +182,7 @@
155182
end
156183

157184
it "returns a response" do
158-
expect(instance.detect_language(
185+
expect(subject.detect_language(
159186
texts: ["Здравствуй, Мир"]
160187
).dig("results").first.dig("language_code")).to eq("ru")
161188
end
@@ -172,7 +199,7 @@
172199
end
173200

174201
it "returns a response" do
175-
expect(instance.summarize(
202+
expect(subject.summarize(
176203
text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \
177204
"It may be made from milk or cream and is flavoured with a sweetener, " \
178205
"either sugar or an alternative, and a spice, such as cocoa or vanilla, " \

spec/fixtures/rerank.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"id": "fd2f37a7-78e5-4d43-9230-ca0804f8cab5",
3+
"results": [
4+
{
5+
"index": 3,
6+
"relevance_score": 0.97997653
7+
},
8+
{
9+
"index": 4,
10+
"relevance_score": 0.27963173
11+
},
12+
{
13+
"index": 2,
14+
"relevance_score": 0.10502681
15+
},
16+
{
17+
"index": 0,
18+
"relevance_score": 0.10212547
19+
},
20+
{
21+
"index": 1,
22+
"relevance_score": 0.0721122
23+
}
24+
],
25+
"meta": {
26+
"api_version": {
27+
"version": "1"
28+
},
29+
"billed_units": {
30+
"search_units": 1
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)