Skip to content

Commit 1139ba3

Browse files
committed
Refactored Streaming functionality globally, refactored Gemini provider, introduced streaming error parsing for Anthropic
1 parent 5d68cc4 commit 1139ba3

File tree

50 files changed (+1144 additions, −1164 deletions)

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed (+1144 additions, −1164 deletions)

lib/ruby_llm/error.rb

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,21 @@ def initialize(response = nil, message = nil)
1919
end
2020
end
2121

22-
class ModelNotFoundError < StandardError; end
22+
# Error classes for non-HTTP errors
23+
class ConfigurationError < StandardError; end
2324
class InvalidRoleError < StandardError; end
25+
class ModelNotFoundError < StandardError; end
2426
class UnsupportedFunctionsError < StandardError; end
25-
class ConfigurationError < StandardError; end
26-
class UnauthorizedError < Error; end
27-
class PaymentRequiredError < Error; end
28-
class ServiceUnavailableError < Error; end
27+
28+
# Error classes for different HTTP status codes
2929
class BadRequestError < Error; end
30+
class ForbiddenError < Error; end
31+
class OverloadedError < Error; end
32+
class PaymentRequiredError < Error; end
3033
class RateLimitError < Error; end
3134
class ServerError < Error; end
35+
class ServiceUnavailableError < Error; end
36+
class UnauthorizedError < Error; end
3237

3338
# Faraday middleware that maps provider-specific API errors to RubyLLM errors.
3439
# Uses provider's parse_error method to extract meaningful error messages.
@@ -57,12 +62,17 @@ def parse_error(provider:, response:) # rubocop:disable Metrics/CyclomaticComple
5762
raise UnauthorizedError.new(response, message || 'Invalid API key - check your credentials')
5863
when 402
5964
raise PaymentRequiredError.new(response, message || 'Payment required - please top up your account')
65+
when 403
66+
raise ForbiddenError.new(response,
67+
message || 'Forbidden - you do not have permission to access this resource')
6068
when 429
6169
raise RateLimitError.new(response, message || 'Rate limit exceeded - please wait a moment')
6270
when 500
6371
raise ServerError.new(response, message || 'API server error - please try again')
6472
when 502..503
6573
raise ServiceUnavailableError.new(response, message || 'API server unavailable - please try again later')
74+
when 529
75+
raise OverloadedError.new(response, message || 'Service overloaded - please try again later')
6676
else
6777
raise Error.new(response, message || 'An unknown error occurred')
6878
end

lib/ruby_llm/provider.rb

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ module RubyLLM
77
module Provider
88
# Common functionality for all LLM providers. Implements the core provider
99
# interface so specific providers only need to implement a few key methods.
10-
module Methods # rubocop:disable Metrics/ModuleLength
10+
module Methods
11+
extend Streaming
12+
1113
def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
1214
normalized_temperature = if capabilities.respond_to?(:normalize_temperature)
1315
capabilities.normalize_temperature(temperature, model)
@@ -80,19 +82,6 @@ def sync_response(payload)
8082
parse_completion_response response
8183
end
8284

83-
def stream_response(payload, &block)
84-
accumulator = StreamAccumulator.new
85-
86-
post stream_url, payload do |req|
87-
req.options.on_data = handle_stream do |chunk|
88-
accumulator.add chunk
89-
block.call chunk
90-
end
91-
end
92-
93-
accumulator.to_message
94-
end
95-
9685
def post(url, payload)
9786
connection.post url, payload do |req|
9887
req.headers.merge! headers
@@ -141,33 +130,6 @@ def connection # rubocop:disable Metrics/MethodLength,Metrics/AbcSize
141130
f.use :llm_errors, provider: self
142131
end
143132
end
144-
145-
def to_json_stream(&block) # rubocop:disable Metrics/MethodLength
146-
buffer = String.new
147-
parser = EventStreamParser::Parser.new
148-
149-
proc do |chunk, _bytes, env|
150-
if env && env.status != 200
151-
# Accumulate error chunks
152-
buffer << chunk
153-
begin
154-
error_data = JSON.parse(buffer)
155-
error_response = env.merge(body: error_data)
156-
ErrorMiddleware.parse_error(provider: self, response: error_response)
157-
rescue JSON::ParserError
158-
# Keep accumulating if we don't have complete JSON yet
159-
RubyLLM.logger.debug "Accumulating error chunk: #{chunk}"
160-
end
161-
else
162-
parser.feed(chunk) do |_type, data|
163-
unless data == '[DONE]'
164-
parsed_data = JSON.parse(data)
165-
block.call(parsed_data)
166-
end
167-
end
168-
end
169-
end
170-
end
171133
end
172134

173135
def try_parse_json(maybe_json)
@@ -207,6 +169,7 @@ def parse_data_uri(uri)
207169
class << self
208170
def extended(base)
209171
base.extend(Methods)
172+
base.extend(Streaming)
210173
end
211174

212175
def register(name, provider_module)

lib/ruby_llm/providers/anthropic/streaming.rb

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ def stream_url
1111
completion_url
1212
end
1313

14-
def handle_stream(&block)
15-
to_json_stream do |data|
16-
block.call(build_chunk(data))
17-
end
18-
end
19-
2014
def build_chunk(data)
2115
Chunk.new(
2216
role: :assistant,
@@ -31,6 +25,18 @@ def build_chunk(data)
3125
def json_delta?(data)
3226
data['type'] == 'content_block_delta' && data.dig('delta', 'type') == 'input_json_delta'
3327
end
28+
29+
def parse_streaming_error(data)
30+
error_data = JSON.parse(data)
31+
return unless error_data['type'] == 'error'
32+
33+
case error_data.dig('error', 'type')
34+
when 'overloaded_error'
35+
[529, error_data['error']['message']]
36+
else
37+
[500, error_data['error']['message']]
38+
end
39+
end
3440
end
3541
end
3642
end

lib/ruby_llm/providers/gemini/chat.rb

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@ module RubyLLM
44
module Providers
55
module Gemini
66
# Chat methods for the Gemini API implementation
7-
module Chat # rubocop:disable Metrics/ModuleLength
8-
# Must be public for Provider to use
7+
module Chat
8+
def completion_url
9+
"models/#{@model}:generateContent"
10+
end
11+
912
def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
13+
@model = model
1014
payload = {
1115
contents: format_messages(messages),
1216
generationConfig: {
@@ -20,26 +24,15 @@ def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable M
2024
@tools = tools
2125

2226
if block_given?
23-
stream_completion(model, payload, &block)
27+
stream_response payload, &block
2428
else
25-
generate_completion(model, payload)
29+
sync_response payload
2630
end
2731
end
2832

2933
# Format methods can be private
3034
private
3135

32-
def generate_completion(model, payload)
33-
url = "models/#{model}:generateContent"
34-
response = post(url, payload)
35-
result = parse_completion_response(response)
36-
37-
# If this contains a tool call, log it
38-
result.tool_calls.values.first if result.tool_call?
39-
40-
result
41-
end
42-
4336
def format_messages(messages)
4437
messages.map do |msg|
4538
{

lib/ruby_llm/providers/gemini/images.rb

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,13 @@ module Providers
55
module Gemini
66
# Image generation methods for the Gemini API implementation
77
module Images
8-
def images_url(model:)
9-
"models/#{model}:predict"
8+
def images_url
9+
"models/#{@model}:predict"
1010
end
1111

12-
def paint(prompt, model:, size:) # rubocop:disable Lint/UnusedMethodArgument
13-
payload = render_image_payload(prompt)
14-
15-
response = post(images_url(model:), payload)
16-
parse_image_response(response)
17-
end
18-
19-
def render_image_payload(prompt)
12+
def render_image_payload(prompt, model:, size:) # rubocop:disable Metrics/MethodLength
13+
RubyLLM.logger.debug "Ignoring size #{size}. Gemini does not support image size customization."
14+
@model = model
2015
{
2116
instances: [
2217
{

lib/ruby_llm/providers/gemini/streaming.rb

Lines changed: 35 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,93 +5,52 @@ module Providers
55
module Gemini
66
# Streaming methods for the Gemini API implementation
77
module Streaming
8-
# Need to make stream_completion public for chat.rb to access
9-
def stream_completion(model, payload, &block) # rubocop:disable Metrics/AbcSize,Metrics/MethodLength
10-
url = "models/#{model}:streamGenerateContent?alt=sse"
11-
accumulator = StreamAccumulator.new
12-
13-
post(url, payload) do |req|
14-
req.options.on_data = stream_handler(accumulator, &block)
15-
end
16-
17-
# If this is a tool call, immediately execute it and include the result
18-
message = accumulator.to_message
19-
if message.tool_call? && message.content.to_s.empty? && @tools && !@tools.empty?
20-
tool_call = message.tool_calls.values.first
21-
tool = @tools[tool_call.name.to_sym]
22-
23-
if tool
24-
tool_result = tool.call(tool_call.arguments)
25-
# Create a new chunk with the result
26-
result_chunk = Chunk.new(
27-
role: :assistant,
28-
content: "The result is #{tool_result}",
29-
model_id: message.model_id,
30-
input_tokens: message.input_tokens,
31-
output_tokens: message.output_tokens,
32-
tool_calls: message.tool_calls
33-
)
34-
35-
# Add to accumulator and call the block
36-
accumulator.add(result_chunk)
37-
block.call(result_chunk)
38-
end
39-
end
8+
def stream_url
9+
"models/#{@model}:streamGenerateContent?alt=sse"
10+
end
4011

41-
accumulator.to_message
12+
def build_chunk(data)
13+
Chunk.new(
14+
role: :assistant,
15+
model_id: extract_model_id(data),
16+
content: extract_content(data),
17+
input_tokens: extract_input_tokens(data),
18+
output_tokens: extract_output_tokens(data),
19+
tool_calls: extract_tool_calls(data)
20+
)
4221
end
4322

4423
private
4524

46-
# Handle streaming
47-
def stream_handler(accumulator, &block) # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metrics/MethodLength,Metrics/PerceivedComplexity
48-
to_json_stream do |data| # rubocop:disable Metrics/BlockLength
49-
next unless data['candidates']&.any?
50-
51-
candidate = data['candidates'][0]
52-
parts = candidate.dig('content', 'parts')
53-
model_id = data['modelVersion']
25+
def extract_model_id(data)
26+
data['modelVersion']
27+
end
5428

55-
# First attempt to extract tool calls
56-
tool_calls = nil
29+
def extract_content(data)
30+
return nil unless data['candidates']&.any?
5731

58-
# Check if any part contains a functionCall
59-
if parts&.any? { |p| p['functionCall'] }
60-
function_part = parts.find { |p| p['functionCall'] }
61-
function_data = function_part['functionCall']
32+
candidate = data['candidates'][0]
33+
parts = candidate.dig('content', 'parts')
34+
return nil unless parts
6235

63-
if function_data && function_data['name']
64-
# Create a tool call with proper structure - convert args to JSON string
65-
id = SecureRandom.uuid
66-
tool_calls = {
67-
id => ToolCall.new(
68-
id: id,
69-
name: function_data['name'],
70-
arguments: JSON.generate(function_data['args']) # Convert Hash to JSON string
71-
)
72-
}
73-
end
74-
end
36+
text_parts = parts.select { |p| p['text'] }
37+
text_parts.map { |p| p['text'] }.join if text_parts.any?
38+
end
7539

76-
# Extract text content (if any)
77-
text = nil
78-
if parts
79-
text_parts = parts.select { |p| p['text'] }
80-
text = text_parts.map { |p| p['text'] }.join if text_parts.any?
81-
end
40+
def extract_input_tokens(data)
41+
data.dig('usageMetadata', 'promptTokenCount')
42+
end
8243

83-
chunk = Chunk.new(
84-
role: :assistant,
85-
content: text,
86-
model_id: model_id,
87-
input_tokens: data.dig('usageMetadata', 'promptTokenCount'),
88-
output_tokens: data.dig('usageMetadata', 'candidatesTokenCount'),
89-
tool_calls: tool_calls
90-
)
44+
def extract_output_tokens(data)
45+
data.dig('usageMetadata', 'candidatesTokenCount')
46+
end
9147

92-
accumulator.add(chunk)
93-
block.call(chunk)
94-
end
48+
def parse_streaming_error(data)
49+
error_data = JSON.parse(data)
50+
[error_data['error']['code'], error_data['error']['message']]
51+
rescue JSON::ParserError => e
52+
RubyLLM.logger.debug "Failed to parse streaming error: #{e.message}"
53+
[500, "Failed to parse error: #{data}"]
9554
end
9655
end
9756
end

lib/ruby_llm/providers/openai/streaming.rb

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,15 @@ def stream_url
1111
completion_url
1212
end
1313

14-
def handle_stream(&block) # rubocop:disable Metrics/MethodLength
15-
to_json_stream do |data|
16-
block.call(
17-
Chunk.new(
18-
role: :assistant,
19-
model_id: data['model'],
20-
content: data.dig('choices', 0, 'delta', 'content'),
21-
tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
22-
input_tokens: data.dig('usage', 'prompt_tokens'),
23-
output_tokens: data.dig('usage', 'completion_tokens')
24-
)
25-
)
26-
end
14+
def build_chunk(data)
15+
Chunk.new(
16+
role: :assistant,
17+
model_id: data['model'],
18+
content: data.dig('choices', 0, 'delta', 'content'),
19+
tool_calls: parse_tool_calls(data.dig('choices', 0, 'delta', 'tool_calls'), parse_arguments: false),
20+
input_tokens: data.dig('usage', 'prompt_tokens'),
21+
output_tokens: data.dig('usage', 'completion_tokens')
22+
)
2723
end
2824
end
2925
end

0 commit comments

Comments (0)