|
3 | 3 | require "spec_helper" |
4 | 4 |
|
5 | 5 | RSpec.describe Cohere::Client do |
6 | | - let(:instance) { described_class.new(api_key: "123") } |
| 6 | + subject { described_class.new(api_key: "123") } |
7 | 7 |
|
8 | 8 | describe "#generate" do |
9 | 9 | let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) } |
|
16 | 16 | end |
17 | 17 |
|
18 | 18 | it "returns a response" do |
19 | | - expect(instance.generate( |
| 19 | + expect(subject.generate( |
20 | 20 | prompt: "Once upon a time in a magical land called" |
21 | 21 | ).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.") |
22 | 22 | end |
|
33 | 33 | end |
34 | 34 |
|
35 | 35 | it "returns a response" do |
36 | | - expect(instance.embed( |
| 36 | + expect(subject.embed( |
37 | 37 | texts: ["hello!"] |
38 | 38 | ).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]]) |
39 | 39 | end |
40 | 40 | end |
41 | 41 |
|
| 42 | + describe "#rerank" do |
| 43 | + let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) } |
| 44 | + let(:response) { OpenStruct.new(body: embed_result) } |
| 45 | + let(:docs) { |
| 46 | + [ |
| 47 | + "Carson City is the capital city of the American state of Nevada.", |
| 48 | + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", |
| 49 | + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", |
| 50 | + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", |
| 51 | + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." |
| 52 | + ] |
| 53 | + } |
| 54 | + |
| 55 | + before do |
| 56 | + allow_any_instance_of(Faraday::Connection).to receive(:post) |
| 57 | + .with("rerank") |
| 58 | + .and_return(response) |
| 59 | + end |
| 60 | + |
| 61 | + it "returns a response" do |
| 62 | + expect( |
| 63 | + subject |
| 64 | + .rerank(query: "What is the capital of the United States?", documents: docs) |
| 65 | + .dig("results") |
| 66 | + .map { |h| h["index"] } |
| 67 | + ).to eq([3, 4, 2, 0, 1]) |
| 68 | + end |
| 69 | + end |
| 70 | + |
42 | 71 | describe "#classify" do |
43 | 72 | let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) } |
44 | 73 | let(:response) { OpenStruct.new(body: classify_result) } |
|
64 | 93 | end |
65 | 94 |
|
66 | 95 | it "returns a response" do |
67 | | - res = instance.classify( |
| 96 | + res = subject.classify( |
68 | 97 | inputs: inputs, |
69 | 98 | examples: examples |
70 | 99 | ).dig("classifications") |
|
85 | 114 | end |
86 | 115 |
|
87 | 116 | it "returns a response" do |
88 | | - expect(instance.tokenize( |
| 117 | + expect(subject.tokenize( |
89 | 118 | text: "Hello, world!" |
90 | 119 | ).dig("tokens")).to eq([33555, 1114, 34]) |
91 | 120 | end |
|
102 | 131 | end |
103 | 132 |
|
104 | 133 | it "returns a response" do |
105 | | - expect(instance.tokenize( |
| 134 | + expect(subject.tokenize( |
106 | 135 | text: "Hello, world!", |
107 | 136 | model: "base" |
108 | 137 | ).dig("tokens")).to eq([33555, 1114, 34]) |
|
120 | 149 | end |
121 | 150 |
|
122 | 151 | it "returns a response" do |
123 | | - expect(instance.detokenize( |
| 152 | + expect(subject.detokenize( |
124 | 153 | tokens: [33555, 1114, 34] |
125 | 154 | ).dig("text")).to eq("hello world!") |
126 | 155 | end |
|
137 | 166 | end |
138 | 167 |
|
139 | 168 | it "returns a response" do |
140 | | - expect(instance.detokenize( |
| 169 | + expect(subject.detokenize( |
141 | 170 | tokens: [33555, 1114, 34], |
142 | 171 | model: "base" |
143 | 172 | ).dig("text")).to eq("hello world!") |
|
155 | 184 | end |
156 | 185 |
|
157 | 186 | it "returns a response" do |
158 | | - expect(instance.detect_language( |
| 187 | + expect(subject.detect_language( |
159 | 188 | texts: ["Здравствуй, Мир"] |
160 | 189 | ).dig("results").first.dig("language_code")).to eq("ru") |
161 | 190 | end |
|
172 | 201 | end |
173 | 202 |
|
174 | 203 | it "returns a response" do |
175 | | - expect(instance.summarize( |
| 204 | + expect(subject.summarize( |
176 | 205 | text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \ |
177 | 206 | "It may be made from milk or cream and is flavoured with a sweetener, " \ |
178 | 207 | "either sugar or an alternative, and a spice, such as cocoa or vanilla, " \ |
|
0 commit comments