Skip to content

Commit 47326f1

Browse files
Merge pull request #1 from andreibondarev/fixes
fixes
2 parents 0756f61 + 6179ee1 commit 47326f1

File tree

5 files changed

+46
-34
lines changed

5 files changed

+46
-34
lines changed

Gemfile.lock

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ GEM
100100

101101
PLATFORMS
102102
x86_64-darwin-19
103+
x86_64-linux
103104

104105
DEPENDENCIES
105106
google_palm_api!

README.md

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,20 +6,37 @@ Welcome to your new gem! In this directory, you'll find the files you need to be
66

77
## Installation
88

9-
TODO: Replace `UPDATE_WITH_YOUR_GEM_NAME_PRIOR_TO_RELEASE_TO_RUBYGEMS_ORG` with your gem name right after releasing it to RubyGems.org. Please do not do it earlier due to security reasons. Alternatively, replace this section with instructions to install your gem from git if you don't plan to release to RubyGems.org.
10-
119
Install the gem and add to the application's Gemfile by executing:
1210

13-
$ bundle add UPDATE_WITH_YOUR_GEM_NAME_PRIOR_TO_RELEASE_TO_RUBYGEMS_ORG
11+
$ bundle add google_palm_api
1412

1513
If bundler is not being used to manage dependencies, install the gem by executing:
1614

17-
$ gem install UPDATE_WITH_YOUR_GEM_NAME_PRIOR_TO_RELEASE_TO_RUBYGEMS_ORG
15+
$ gem install google_palm_api
1816

1917
## Usage
2018

21-
TODO: Write usage instructions here
22-
19+
```ruby
20+
require 'google_palm_api'
21+
```
22+
```ruby
23+
client.generate_text(prompt:)
24+
```
25+
```ruby
26+
client.generate_chat_message(prompt:)
27+
```
28+
```ruby
29+
client.embed(text:)
30+
```
31+
```ruby
32+
client.get_model(model:)
33+
```
34+
```ruby
35+
client.list_models()
36+
```
37+
```ruby
38+
client.count_message_tokens(message:)
39+
```
2340
## Development
2441

2542
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.

lib/google_palm_api.rb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,6 @@
44

55
module GooglePalmApi
66
class Error < StandardError; end
7-
7+
88
autoload :Client, "google_palm_api/client"
99
end

lib/google_palm_api/client.rb

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ class Client
1313
temperature: 0.0,
1414
completion_model_name: "text-bison-001",
1515
chat_completion_model_name: "chat-bison-001",
16-
embeddings_model_name: "embedding-gecko-001",
16+
embeddings_model_name: "embedding-gecko-001"
1717
}
1818

1919
def initialize(api_key:)
@@ -40,7 +40,6 @@ def initialize(api_key:)
4040
#
4141
def generate_text(
4242
prompt:,
43-
model: nil,
4443
temperature: nil,
4544
candidate_count: nil,
4645
max_output_tokens: nil,
@@ -50,10 +49,10 @@ def generate_text(
5049
stop_sequences: nil,
5150
client: nil
5251
)
53-
response = connection.post("/v1beta2/models/#{model || DEFAULTS[:completion_model_name]}:generateText") do |req|
52+
response = connection.post("/v1beta2/models/#{DEFAULTS[:completion_model_name]}:generateText") do |req|
5453
req.params = {key: api_key}
5554

56-
req.body = {prompt: { text: prompt }}
55+
req.body = {prompt: {text: prompt}}
5756
req.body[:temperature] = temperature || DEFAULTS[:temperature]
5857
req.body[:candidate_count] = candidate_count if candidate_count
5958
req.body[:max_output_tokens] = max_output_tokens if max_output_tokens
@@ -87,7 +86,6 @@ def generate_text(
8786
#
8887
def generate_chat_message(
8988
prompt:,
90-
model: nil,
9189
context: nil,
9290
examples: nil,
9391
messages: nil,
@@ -98,11 +96,10 @@ def generate_chat_message(
9896
client: nil
9997
)
10098
# Overwrite the default ENDPOINT_URL for this method.
101-
response = connection.post("/v1beta2/models/#{model || DEFAULTS[:chat_completion_model_name]}:generateMessage") do |req|
99+
response = connection.post("/v1beta2/models/#{DEFAULTS[:chat_completion_model_name]}:generateMessage") do |req|
102100
req.params = {key: api_key}
103101

104-
req.body = {prompt: { messages: [{content: prompt}] }}
105-
req.body[:model] = model if model
102+
req.body = {prompt: {messages: [{content: prompt}]}}
106103
req.body[:context] = context if context
107104
req.body[:examples] = examples if examples
108105
req.body[:messages] = messages if messages
@@ -113,7 +110,7 @@ def generate_chat_message(
113110
req.body[:client] = client if client
114111
end
115112
response.body
116-
end
113+
end
117114

118115
#
119116
# The embedding service in the PaLM API generates state-of-the-art embeddings for words, phrases, and sentences.
@@ -144,49 +141,49 @@ def embed(
144141

145142
#
146143
# Lists models available through the API.
147-
#
144+
#
148145
# @param [Integer] page_size
149146
# @param [String] page_token
150147
# @return [Hash]
151-
#
148+
#
152149
def list_models(page_size: nil, page_token: nil)
153150
response = connection.get("/v1beta2/models") do |req|
154151
req.params = {key: api_key}
155152

156153
req.params[:pageSize] = page_size if page_size
157154
req.params[:pageToken] = page_token if page_token
158155
end
159-
response.body
156+
response.body
160157
end
161158

162-
#
159+
#
163160
# Runs a model's tokenizer on a string and returns the token count.
164-
#
161+
#
165162
# @param [String] model
166163
# @param [String] prompt
167164
# @return [Hash]
168-
#
165+
#
169166
def count_message_tokens(model:, prompt:)
170167
response = connection.post("/v1beta2/models/#{model}:countMessageTokens") do |req|
171168
req.params = {key: api_key}
172169

173-
req.body = {prompt: { messages: [{content: prompt}] }}
170+
req.body = {prompt: {messages: [{content: prompt}]}}
174171
end
175172
response.body
176173
end
177174

178175
#
179176
# Gets information about a specific Model.
180-
#
177+
#
181178
# @param [String] name
182179
# @return [Hash]
183-
#
180+
#
184181
def get_model(model:)
185182
response = connection.get("/v1beta2/models/#{model}") do |req|
186183
req.params = {key: api_key}
187184
end
188185
response.body
189-
end
186+
end
190187

191188
private
192189

spec/google_palm_api/client_spec.rb

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
# frozen_string_literal: true
22

3-
require "cohere"
4-
53
RSpec.describe GooglePalmApi::Client do
64
subject { described_class.new(api_key: "123") }
75

@@ -71,8 +69,8 @@
7169

7270
before do
7371
allow_any_instance_of(Faraday::Connection).to receive(:post)
74-
.with("/v1beta2/models/text-bison-001:generateText")
75-
.and_return(response)
72+
.with("/v1beta2/models/text-bison-001:generateText")
73+
.and_return(response)
7674
end
7775

7876
it "returns the generated text" do
@@ -84,15 +82,14 @@
8482
let(:fixture) { JSON.parse(File.read("spec/fixtures/generate_chat_message.json")) }
8583
let(:response) { OpenStruct.new(body: fixture) }
8684

87-
8885
before do
8986
allow_any_instance_of(Faraday::Connection).to receive(:post)
90-
.with("/v1beta2/models/chat-bison-001:generateMessage")
91-
.and_return(response)
87+
.with("/v1beta2/models/chat-bison-001:generateMessage")
88+
.and_return(response)
9289
end
9390

9491
it "returns the generated text" do
9592
expect(subject.generate_chat_message(prompt: "Hello!")).to eq(fixture)
96-
end
93+
end
9794
end
9895
end

0 commit comments

Comments (0)