Skip to content

Commit 8b25faa

Browse files
Merge pull request #21 from patterns-ai-core/rerank-method
Add rerank method and bump version
2 parents 320d9c8 + 449e39a commit 8b25faa

File tree

7 files changed

+137
-34
lines changed

7 files changed

+137
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
## [Unreleased]
22

3+
## [0.9.11] - 2024-08-01
4+
- New `rerank()` method
5+
36
## [0.9.10] - 2024-05-10
47
- /chat endpoint does not require `message:` parameter anymore
58

Gemfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
PATH
22
remote: .
33
specs:
4-
cohere-ruby (0.9.10)
4+
cohere-ruby (0.9.11)
55
faraday (>= 2.0.1, < 3.0)
66

77
GEM

README.md

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Cohere
22

33
<p>
4-
<img alt='Weaviate logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
4+
<img alt='Cohere logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
55
+&nbsp;&nbsp;
66
<img alt='Ruby logo' src='https://user-images.githubusercontent.com/541665/230231593-43861278-4550-421d-a543-fd3553aac4f6.png' height='40' />
77
</p>
@@ -42,15 +42,15 @@ client = Cohere::Client.new(
4242

4343
```ruby
4444
client.generate(
45-
prompt: "Once upon a time in a magical land called"
45+
prompt: "Once upon a time in a magical land called"
4646
)
4747
```
4848

4949
### Chat
5050

5151
```ruby
5252
client.chat(
53-
message: "Hey! How are you?"
53+
message: "Hey! How are you?"
5454
)
5555
```
5656

@@ -90,30 +90,45 @@ client.chat(
9090
)
9191
```
9292

93-
94-
9593
### Embed
9694

9795
```ruby
9896
client.embed(
99-
texts: ["hello!"]
97+
texts: ["hello!"]
98+
)
99+
```
100+
101+
### Rerank
102+
103+
```ruby
104+
docs = [
105+
"Carson City is the capital city of the American state of Nevada.",
106+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
107+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
108+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
109+
"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
110+
]
111+
112+
client.rerank(
113+
query: "What is the capital of the United States?", documents: docs
100114
)
101115
```
102116

117+
103118
### Classify
104119

105120
```ruby
106121
examples = [
107-
{ text: "Dermatologists don't like her!", label: "Spam" },
108-
{ text: "Hello, open to this?", label: "Spam" },
109-
{ text: "I need help please wire me $1000 right now", label: "Spam" },
110-
{ text: "Nice to know you ;)", label: "Spam" },
111-
{ text: "Please help me?", label: "Spam" },
112-
{ text: "Your parcel will be delivered today", label: "Not spam" },
113-
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
114-
{ text: "Weekly sync notes", label: "Not spam" },
115-
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
116-
{ text: "Pre-read for tomorrow", label: "Not spam" }
122+
{ text: "Dermatologists don't like her!", label: "Spam" },
123+
{ text: "Hello, open to this?", label: "Spam" },
124+
{ text: "I need help please wire me $1000 right now", label: "Spam" },
125+
{ text: "Nice to know you ;)", label: "Spam" },
126+
{ text: "Please help me?", label: "Spam" },
127+
{ text: "Your parcel will be delivered today", label: "Not spam" },
128+
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
129+
{ text: "Weekly sync notes", label: "Not spam" },
130+
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
131+
{ text: "Pre-read for tomorrow", label: "Not spam" }
117132
]
118133

119134
inputs = [
@@ -122,40 +137,40 @@ inputs = [
122137
]
123138

124139
client.classify(
125-
examples: examples,
126-
inputs: inputs
140+
examples: examples,
141+
inputs: inputs
127142
)
128143
```
129144

130145
### Tokenize
131146

132147
```ruby
133148
client.tokenize(
134-
text: "hello world!"
149+
text: "hello world!"
135150
)
136151
```
137152

138153
### Detokenize
139154

140155
```ruby
141156
client.detokenize(
142-
tokens: [33555, 1114 , 34]
157+
tokens: [33555, 1114 , 34]
143158
)
144159
```
145160

146161
### Detect language
147162

148163
```ruby
149164
client.detect_language(
150-
texts: ["Здравствуй, Мир"]
165+
texts: ["Здравствуй, Мир"]
151166
)
152167
```
153168

154169
### Summarize
155170

156171
```ruby
157172
client.summarize(
158-
text: "..."
173+
text: "..."
159174
)
160175
```
161176

lib/cohere/client.rb

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,29 @@ def embed(
119119
response.body
120120
end
121121

122+
def rerank(
123+
query:,
124+
documents:,
125+
model: nil,
126+
top_n: nil,
127+
rank_fields: nil,
128+
return_documents: nil,
129+
max_chunks_per_doc: nil
130+
)
131+
response = connection.post("rerank") do |req|
132+
req.body = {
133+
query: query,
134+
documents: documents
135+
}
136+
req.body[:model] = model if model
137+
req.body[:top_n] = top_n if top_n
138+
req.body[:rank_fields] = rank_fields if rank_fields
139+
req.body[:return_documents] = return_documents if return_documents
140+
req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc
141+
end
142+
response.body
143+
end
144+
122145
def classify(
123146
inputs:,
124147
examples:,

lib/cohere/version.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# frozen_string_literal: true
22

33
module Cohere
4-
VERSION = "0.9.10"
4+
VERSION = "0.9.11"
55
end

spec/cohere/client_spec.rb

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
require "spec_helper"
44

55
RSpec.describe Cohere::Client do
6-
let(:instance) { described_class.new(api_key: "123") }
6+
subject { described_class.new(api_key: "123") }
77

88
describe "#generate" do
99
let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) }
@@ -16,7 +16,7 @@
1616
end
1717

1818
it "returns a response" do
19-
expect(instance.generate(
19+
expect(subject.generate(
2020
prompt: "Once upon a time in a magical land called"
2121
).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.")
2222
end
@@ -33,12 +33,41 @@
3333
end
3434

3535
it "returns a response" do
36-
expect(instance.embed(
36+
expect(subject.embed(
3737
texts: ["hello!"]
3838
).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]])
3939
end
4040
end
4141

42+
describe "#rerank" do
43+
let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) }
44+
let(:response) { OpenStruct.new(body: embed_result) }
45+
let(:docs) {
46+
[
47+
"Carson City is the capital city of the American state of Nevada.",
48+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
49+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
50+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
51+
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
52+
]
53+
}
54+
55+
before do
56+
allow_any_instance_of(Faraday::Connection).to receive(:post)
57+
.with("rerank")
58+
.and_return(response)
59+
end
60+
61+
it "returns a response" do
62+
expect(
63+
subject
64+
.rerank(query: "What is the capital of the United States?", documents: docs)
65+
.dig("results")
66+
.map { |h| h["index"] }
67+
).to eq([3, 4, 2, 0, 1])
68+
end
69+
end
70+
4271
describe "#classify" do
4372
let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) }
4473
let(:response) { OpenStruct.new(body: classify_result) }
@@ -64,7 +93,7 @@
6493
end
6594

6695
it "returns a response" do
67-
res = instance.classify(
96+
res = subject.classify(
6897
inputs: inputs,
6998
examples: examples
7099
).dig("classifications")
@@ -85,7 +114,7 @@
85114
end
86115

87116
it "returns a response" do
88-
expect(instance.tokenize(
117+
expect(subject.tokenize(
89118
text: "Hello, world!"
90119
).dig("tokens")).to eq([33555, 1114, 34])
91120
end
@@ -102,7 +131,7 @@
102131
end
103132

104133
it "returns a response" do
105-
expect(instance.tokenize(
134+
expect(subject.tokenize(
106135
text: "Hello, world!",
107136
model: "base"
108137
).dig("tokens")).to eq([33555, 1114, 34])
@@ -120,7 +149,7 @@
120149
end
121150

122151
it "returns a response" do
123-
expect(instance.detokenize(
152+
expect(subject.detokenize(
124153
tokens: [33555, 1114, 34]
125154
).dig("text")).to eq("hello world!")
126155
end
@@ -137,7 +166,7 @@
137166
end
138167

139168
it "returns a response" do
140-
expect(instance.detokenize(
169+
expect(subject.detokenize(
141170
tokens: [33555, 1114, 34],
142171
model: "base"
143172
).dig("text")).to eq("hello world!")
@@ -155,7 +184,7 @@
155184
end
156185

157186
it "returns a response" do
158-
expect(instance.detect_language(
187+
expect(subject.detect_language(
159188
texts: ["Здравствуй, Мир"]
160189
).dig("results").first.dig("language_code")).to eq("ru")
161190
end
@@ -172,7 +201,7 @@
172201
end
173202

174203
it "returns a response" do
175-
expect(instance.summarize(
204+
expect(subject.summarize(
176205
text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \
177206
"It may be made from milk or cream and is flavoured with a sweetener, " \
178207
"either sugar or an alternative, and a spice, such as cocoa or vanilla, " \

spec/fixtures/rerank.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"id": "fd2f37a7-78e5-4d43-9230-ca0804f8cab5",
3+
"results": [
4+
{
5+
"index": 3,
6+
"relevance_score": 0.97997653
7+
},
8+
{
9+
"index": 4,
10+
"relevance_score": 0.27963173
11+
},
12+
{
13+
"index": 2,
14+
"relevance_score": 0.10502681
15+
},
16+
{
17+
"index": 0,
18+
"relevance_score": 0.10212547
19+
},
20+
{
21+
"index": 1,
22+
"relevance_score": 0.0721122
23+
}
24+
],
25+
"meta": {
26+
"api_version": {
27+
"version": "1"
28+
},
29+
"billed_units": {
30+
"search_units": 1
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)