Commit 0756f61

specs
1 parent 5d295c8 commit 0756f61

8 files changed: +240 additions, −12 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 ## [0.1.0] - 2023-05-23

 - Initial release
+- Added `GooglePalmApi::Client` class that supports `embed()`, `generate_text()`, `generate_chat_message()`, `get_model()`, `list_models()` and `count_message_tokens()` methods
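
As a quick orientation for the new surface area, a minimal usage sketch (assuming a valid PaLM API key exported as PALM_API_KEY; the returned hash shapes match the fixtures added below):

require "google_palm_api"

client = GooglePalmApi::Client.new(api_key: ENV["PALM_API_KEY"])

# Generation and embedding endpoints; model names default to Client::DEFAULTS.
client.generate_text(prompt: "Tell me a joke")    # => {"candidates" => [...]}
client.generate_chat_message(prompt: "Hello!")    # => {"candidates" => [...], "messages" => [...]}
client.embed(text: "Hello world!")                # => {"embedding" => {"value" => [...]}}

# Model discovery and token accounting.
client.list_models                                # => {"models" => [...]}
client.get_model(model: "chat-bison-001")
client.count_message_tokens(model: "chat-bison-001", prompt: "Hello")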

lib/google_palm_api/client.rb

Lines changed: 53 additions & 8 deletions
@@ -8,7 +8,6 @@ class Client
     attr_reader :api_key, :connection

     ENDPOINT_URL = "https://generativelanguage.googleapis.com/"
-    AUTOPUSH_ENDPOINT_URL = "https://autopush-generativelanguage.sandbox.googleapis.com/"

     DEFAULTS = {
       temperature: 0.0,

@@ -51,11 +50,11 @@ def generate_text(
       stop_sequences: nil,
       client: nil
     )
-      response = connection(url: ENDPOINT_URL).post("/v1beta2/models/#{model || DEFAULTS[:completion_model_name]}:generateText") do |req|
+      response = connection.post("/v1beta2/models/#{model || DEFAULTS[:completion_model_name]}:generateText") do |req|
         req.params = {key: api_key}

         req.body = {prompt: { text: prompt }}
-        req.body[:temperature] = temperature if temperature
+        req.body[:temperature] = temperature || DEFAULTS[:temperature]
         req.body[:candidate_count] = candidate_count if candidate_count
         req.body[:max_output_tokens] = max_output_tokens if max_output_tokens
         req.body[:top_p] = top_p if top_p

@@ -99,15 +98,15 @@ def generate_chat_message(
       client: nil
     )
       # Overwrite the default ENDPOINT_URL for this method.
-      response = connection(url: AUTOPUSH_ENDPOINT_URL).post("/v1beta2/models/#{model || DEFAULTS[:chat_completion_model_name]}:generateMessage") do |req|
+      response = connection.post("/v1beta2/models/#{model || DEFAULTS[:chat_completion_model_name]}:generateMessage") do |req|
         req.params = {key: api_key}

         req.body = {prompt: { messages: [{content: prompt}] }}
         req.body[:model] = model if model
         req.body[:context] = context if context
         req.body[:examples] = examples if examples
         req.body[:messages] = messages if messages
-        req.body[:temperature] = temperature if temperature
+        req.body[:temperature] = temperature || DEFAULTS[:temperature]
         req.body[:candidate_count] = candidate_count if candidate_count
         req.body[:top_p] = top_p if top_p
         req.body[:top_k] = top_k if top_k

@@ -133,7 +132,7 @@ def embed(
       model: nil,
       client: nil
     )
-      response = connection(url: ENDPOINT_URL).post("/v1beta2/models/#{model || DEFAULTS[:embeddings_model_name]}:embedText") do |req|
+      response = connection.post("/v1beta2/models/#{model || DEFAULTS[:embeddings_model_name]}:embedText") do |req|
         req.params = {key: api_key}

         req.body = {text: text}

@@ -143,11 +142,57 @@ def embed(
       response.body
     end

+    #
+    # Lists models available through the API.
+    #
+    # @param [Integer] page_size
+    # @param [String] page_token
+    # @return [Hash]
+    #
+    def list_models(page_size: nil, page_token: nil)
+      response = connection.get("/v1beta2/models") do |req|
+        req.params = {key: api_key}
+
+        req.params[:pageSize] = page_size if page_size
+        req.params[:pageToken] = page_token if page_token
+      end
+      response.body
+    end
+
+    #
+    # Runs a model's tokenizer on a string and returns the token count.
+    #
+    # @param [String] model
+    # @param [String] prompt
+    # @return [Hash]
+    #
+    def count_message_tokens(model:, prompt:)
+      response = connection.post("/v1beta2/models/#{model}:countMessageTokens") do |req|
+        req.params = {key: api_key}
+
+        req.body = {prompt: { messages: [{content: prompt}] }}
+      end
+      response.body
+    end
+
+    #
+    # Gets information about a specific Model.
+    #
+    # @param [String] model
+    # @return [Hash]
+    #
+    def get_model(model:)
+      response = connection.get("/v1beta2/models/#{model}") do |req|
+        req.params = {key: api_key}
+      end
+      response.body
+    end
+
     private

     # standard:disable Lint/DuplicateMethods
-    def connection(url:)
-      Faraday.new(url: url) do |faraday|
+    def connection
+      Faraday.new(url: ENDPOINT_URL) do |faraday|
         faraday.request :json
         faraday.response :json, content_type: /\bjson$/
         faraday.adapter Faraday.default_adapter
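
Beyond dropping the sandbox endpoint, this hunk changes two behaviors. First, the private connection method no longer takes a url: argument, so every request is built against ENDPOINT_URL. Second, temperature handling: in Ruby 0.0 is truthy, so the old `if temperature` guard only skipped nil and otherwise deferred to the API's server-side default; the new form always sends an explicit value. A minimal sketch of the difference:

body = {}
temperature = nil

# Old guard: the key is omitted only when temperature is nil,
# leaving the API's server-side default in effect.
body[:temperature] = temperature if temperature
body # => {}

# New form: nil falls back to DEFAULTS[:temperature] (0.0 above),
# so every request carries an explicit temperature.
body[:temperature] = temperature || 0.0
body # => {:temperature=>0.0}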
spec/fixtures/generate_chat_message.json

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+{
+  "candidates": [
+    {
+      "author": "1",
+      "content": "Hello! How can I help you today?"
+    }
+  ],
+  "messages": [
+    {
+      "author": "0",
+      "content": "Hello!"
+    }
+  ]
+}

spec/fixtures/generate_text.json

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+{
+  "candidates": [
+    {
+      "output": "A man walks into a library and asks for books about paranoia. The librarian whispers, \"They're right behind you!\"",
+      "safetyRatings": [
+        {
+          "category": "HARM_CATEGORY_DEROGATORY",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_TOXICITY",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_VIOLENCE",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_SEXUAL",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_MEDICAL",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_DANGEROUS",
+          "probability": "NEGLIGIBLE"
+        }
+      ]
+    }
+  ]
+}

spec/fixtures/list_models.json

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+{
+  "models": [
+    {
+      "name": "models/chat-bison-001",
+      "version": "001",
+      "displayName": "Chat Bison",
+      "description": "Chat-optimized generative language model.",
+      "inputTokenLimit": 4096,
+      "outputTokenLimit": 1024,
+      "supportedGenerationMethods": ["generateMessage"],
+      "temperature": 0.25,
+      "topP": 0.95,
+      "topK": 40
+    },
+    {
+      "name": "models/text-bison-001",
+      "version": "001",
+      "displayName": "Text Bison",
+      "description": "Model targeted for text generation.",
+      "inputTokenLimit": 8196,
+      "outputTokenLimit": 1024,
+      "supportedGenerationMethods": ["generateText"],
+      "temperature": 0.7,
+      "topP": 0.95,
+      "topK": 40
+    },
+    {
+      "name": "models/embedding-gecko-001",
+      "version": "001",
+      "displayName": "Embedding Gecko",
+      "description": "Obtain a distributed representation of a text.",
+      "inputTokenLimit": 1024,
+      "outputTokenLimit": 1,
+      "supportedGenerationMethods": ["embedText"]
+    }
+  ]
+}
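
The fixture above is a single page of results. Since list_models accepts page_size: and page_token: (see the client diff above), paging could look roughly like this sketch. It assumes the API returns a nextPageToken key on paged responses, which this three-model fixture does not exercise:

client = GooglePalmApi::Client.new(api_key: ENV["PALM_API_KEY"])

models = []
token = nil
loop do
  page = client.list_models(page_size: 50, page_token: token)
  models.concat(page["models"] || [])
  token = page["nextPageToken"]
  break if token.nil? || token.empty?
end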
spec/google_palm_api/client_spec.rb

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# frozen_string_literal: true
+
+require "spec_helper"
+
+RSpec.describe GooglePalmApi::Client do
+  subject { described_class.new(api_key: "123") }
+
+  describe "#list_models" do
+    let(:fixture) { JSON.parse(File.read("spec/fixtures/list_models.json")) }
+    let(:response) { OpenStruct.new(body: fixture) }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:get)
+        .with("/v1beta2/models")
+        .and_return(response)
+    end
+
+    it "returns a list of models" do
+      expect(subject.list_models.dig("models").count).to eq(3)
+    end
+  end
+
+  describe "#get_model" do
+    let(:fixture) { JSON.parse(File.read("spec/fixtures/list_models.json")).dig("models").first }
+    let(:response) { OpenStruct.new(body: fixture) }
+    let(:model) { "chat-bison-001" }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:get)
+        .with("/v1beta2/models/#{model}")
+        .and_return(response)
+    end
+
+    it "returns the model" do
+      expect(subject.get_model(model: model)).to eq(fixture)
+    end
+  end
+
+  describe "#count_message_tokens" do
+    let(:response) { OpenStruct.new(body: {"tokenCount" => 14}) }
+    let(:model) { "chat-bison-001" }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:post)
+        .with("/v1beta2/models/#{model}:countMessageTokens")
+        .and_return(response)
+    end
+
+    it "returns the token count" do
+      expect(subject.count_message_tokens(model: model, prompt: "Hello")).to eq(response.body)
+    end
+  end
+
+  describe "#embed" do
+    let(:response) { OpenStruct.new(body: {"embedding" => {"value" => [0.0071609155, 0.010057832, -0.016587045]}}) }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:post)
+        .with("/v1beta2/models/embedding-gecko-001:embedText")
+        .and_return(response)
+    end
+
+    it "returns the embedding" do
+      expect(subject.embed(text: "Hello world!")).to eq(response.body)
+    end
+  end
+
+  describe "#generate_text" do
+    let(:fixture) { JSON.parse(File.read("spec/fixtures/generate_text.json")) }
+    let(:response) { OpenStruct.new(body: fixture) }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:post)
+        .with("/v1beta2/models/text-bison-001:generateText")
+        .and_return(response)
+    end
+
+    it "returns the generated text" do
+      expect(subject.generate_text(prompt: "Hello")).to eq(fixture)
+    end
+  end
+
+  describe "#generate_chat_message" do
+    let(:fixture) { JSON.parse(File.read("spec/fixtures/generate_chat_message.json")) }
+    let(:response) { OpenStruct.new(body: fixture) }
+
+    before do
+      allow_any_instance_of(Faraday::Connection).to receive(:post)
+        .with("/v1beta2/models/chat-bison-001:generateMessage")
+        .and_return(response)
+    end
+
+    it "returns the generated chat message" do
+      expect(subject.generate_chat_message(prompt: "Hello!")).to eq(fixture)
+    end
+  end
+end
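
Every example in this spec follows the same stubbing pattern: intercept the Faraday verb for the expected path with allow_any_instance_of and hand back an object exposing #body, which is all the client reads from a response. The same shape extends to any new endpoint (path and body here are hypothetical):

allow_any_instance_of(Faraday::Connection).to receive(:get)
  .with("/v1beta2/some/new/endpoint")
  .and_return(OpenStruct.new(body: {"ok" => true}))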

spec/google_palm_api_spec.rb

Lines changed: 0 additions & 4 deletions
@@ -4,8 +4,4 @@
   it "has a version number" do
     expect(GooglePalmApi::VERSION).not_to be nil
   end
-
-  it "does something useful" do
-    expect(false).to eq(true)
-  end
 end

spec/spec_helper.rb

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 # frozen_string_literal: true

 require "google_palm_api"
+require "faraday"
+require "faraday_middleware"
+require "ostruct"
+require "json"

 RSpec.configure do |config|
   # Enable flags like --only-failures and --next-failure
