|
3 | 3 | require "spec_helper" |
4 | 4 |
|
5 | 5 | RSpec.describe Cohere::Client do |
6 | | - let(:instance) { described_class.new(api_key: "123") } |
| 6 | + subject { described_class.new(api_key: "123") } |
7 | 7 |
|
8 | 8 | describe "#generate" do |
9 | 9 | let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) } |
|
16 | 16 | end |
17 | 17 |
|
18 | 18 | it "returns a response" do |
19 | | - expect(instance.generate( |
| 19 | + expect(subject.generate( |
20 | 20 | prompt: "Once upon a time in a magical land called" |
21 | 21 | ).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.") |
22 | 22 | end |
|
33 | 33 | end |
34 | 34 |
|
35 | 35 | it "returns a response" do |
36 | | - expect(instance.embed( |
| 36 | + expect(subject.embed( |
37 | 37 | texts: ["hello!"] |
38 | 38 | ).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]]) |
39 | 39 | end |
40 | 40 | end |
41 | 41 |
|
| 42 | + describe "#rerank" do |
| 43 | + let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) } |
| 44 | + let(:response) { OpenStruct.new(body: embed_result) } |
| 45 | + let(:docs) {[ |
| 46 | + "Carson City is the capital city of the American state of Nevada.", |
| 47 | + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", |
| 48 | + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", |
| 49 | + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", |
| 50 | + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", |
| 51 | + ]} |
| 52 | + |
| 53 | + before do |
| 54 | + allow_any_instance_of(Faraday::Connection).to receive(:post) |
| 55 | + .with("rerank") |
| 56 | + .and_return(response) |
| 57 | + end |
| 58 | + |
| 59 | + it "returns a response" do |
| 60 | + expect( |
| 61 | + subject |
| 62 | + .rerank(query: "What is the capital of the United States?", documents: docs) |
| 63 | + .dig("results") |
| 64 | + .map {|h| h["index"]} |
| 65 | + ).to eq([3, 4, 2, 0, 1]) |
| 66 | + end |
| 67 | + end |
| 68 | + |
42 | 69 | describe "#classify" do |
43 | 70 | let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) } |
44 | 71 | let(:response) { OpenStruct.new(body: classify_result) } |
|
64 | 91 | end |
65 | 92 |
|
66 | 93 | it "returns a response" do |
67 | | - res = instance.classify( |
| 94 | + res = subject.classify( |
68 | 95 | inputs: inputs, |
69 | 96 | examples: examples |
70 | 97 | ).dig("classifications") |
|
85 | 112 | end |
86 | 113 |
|
87 | 114 | it "returns a response" do |
88 | | - expect(instance.tokenize( |
| 115 | + expect(subject.tokenize( |
89 | 116 | text: "Hello, world!" |
90 | 117 | ).dig("tokens")).to eq([33555, 1114, 34]) |
91 | 118 | end |
|
102 | 129 | end |
103 | 130 |
|
104 | 131 | it "returns a response" do |
105 | | - expect(instance.tokenize( |
| 132 | + expect(subject.tokenize( |
106 | 133 | text: "Hello, world!", |
107 | 134 | model: "base" |
108 | 135 | ).dig("tokens")).to eq([33555, 1114, 34]) |
|
120 | 147 | end |
121 | 148 |
|
122 | 149 | it "returns a response" do |
123 | | - expect(instance.detokenize( |
| 150 | + expect(subject.detokenize( |
124 | 151 | tokens: [33555, 1114, 34] |
125 | 152 | ).dig("text")).to eq("hello world!") |
126 | 153 | end |
|
137 | 164 | end |
138 | 165 |
|
139 | 166 | it "returns a response" do |
140 | | - expect(instance.detokenize( |
| 167 | + expect(subject.detokenize( |
141 | 168 | tokens: [33555, 1114, 34], |
142 | 169 | model: "base" |
143 | 170 | ).dig("text")).to eq("hello world!") |
|
155 | 182 | end |
156 | 183 |
|
157 | 184 | it "returns a response" do |
158 | | - expect(instance.detect_language( |
| 185 | + expect(subject.detect_language( |
159 | 186 | texts: ["Здравствуй, Мир"] |
160 | 187 | ).dig("results").first.dig("language_code")).to eq("ru") |
161 | 188 | end |
|
172 | 199 | end |
173 | 200 |
|
174 | 201 | it "returns a response" do |
175 | | - expect(instance.summarize( |
| 202 | + expect(subject.summarize( |
176 | 203 | text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \ |
177 | 204 | "It may be made from milk or cream and is flavoured with a sweetener, " \ |
178 | 205 | "either sugar or an alternative, and a spice, such as cocoa or vanilla, " \ |
|
0 commit comments