Skip to content

Commit d617209

Browse files
feat: support streaming uploads (#1)
1 parent d9ce7ce commit d617209

File tree

7 files changed

+254
-67
lines changed

7 files changed

+254
-67
lines changed

lib/openai/pooled_net_requester.rb

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@ def calibrate_socket_timeout(conn, deadline)
4848
#
4949
# @option request [Hash{String=>String}] :headers
5050
#
51+
# @param blk [Proc]
52+
#
5153
# @return [Net::HTTPGenericRequest]
5254
#
53-
def build_request(request)
55+
def build_request(request, &)
5456
method, url, headers, body = request.fetch_values(:method, :url, :headers, :body)
5557
req = Net::HTTPGenericRequest.new(
5658
method.to_s.upcase,
@@ -64,12 +66,14 @@ def build_request(request)
6466
case body
6567
in nil
6668
in String
67-
req.body = body
69+
req["content-length"] ||= body.bytesize.to_s unless req["transfer-encoding"]
70+
req.body_stream = OpenAI::Util::ReadIOAdapter.new(body, &)
6871
in StringIO
69-
req.body = body.string
70-
in IO
71-
body.rewind
72-
req.body_stream = body
72+
req["content-length"] ||= body.size.to_s unless req["transfer-encoding"]
73+
req.body_stream = OpenAI::Util::ReadIOAdapter.new(body, &)
74+
in IO | Enumerator
75+
req["transfer-encoding"] ||= "chunked" unless req["content-length"]
76+
req.body_stream = OpenAI::Util::ReadIOAdapter.new(body, &)
7377
end
7478

7579
req
@@ -97,7 +101,7 @@ def build_request(request)
97101

98102
pool =
99103
@mutex.synchronize do
100-
@pools[origin] ||= ConnectionPool.new(size: Etc.nprocessors) do
104+
@pools[origin] ||= ConnectionPool.new(size: @size) do
101105
self.class.connect(url)
102106
end
103107
end
@@ -128,14 +132,17 @@ def build_request(request)
128132
#
129133
def execute(request)
130134
url, deadline = request.fetch_values(:url, :deadline)
131-
req = self.class.build_request(request)
132135

133136
eof = false
134137
finished = false
135138
enum = Enumerator.new do |y|
136139
with_pool(url) do |conn|
137140
next if finished
138141

142+
req = self.class.build_request(request) do
143+
self.class.calibrate_socket_timeout(conn, deadline)
144+
end
145+
139146
self.class.calibrate_socket_timeout(conn, deadline)
140147
conn.start unless conn.started?
141148

@@ -168,8 +175,13 @@ def execute(request)
168175
[response, (response.body = body)]
169176
end
170177

171-
def initialize
178+
# @private
179+
#
180+
# @param size [Integer]
181+
#
182+
def initialize(size: Etc.nprocessors)
172183
@mutex = Mutex.new
184+
@size = size
173185
@pools = {}
174186
end
175187
end

lib/openai/util.rb

Lines changed: 134 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -399,41 +399,152 @@ def normalized_headers(*headers)
399399
end
400400
end
401401

402+
# @private
403+
#
404+
# An adapter that satisfies the IO interface required by `::IO.copy_stream`
405+
class ReadIOAdapter
  # @private
  #
  # Pulls chunks from the wrapped enumerator into an internal byte buffer and
  # hands back at most `max_len` bytes of it.
  #
  # @param max_len [Integer, nil] maximum number of bytes wanted; `nil` drains the stream
  #
  # @return [String] binary string; empty once the enumerator has nothing left
  #
  private def read_enum(max_len)
    case max_len
    in nil
      # `loop` rescues the StopIteration raised by `next` at the end of the stream,
      # so this drains only what has not been consumed yet and keeps buffered bytes.
      loop { @buf << @stream.next.b }
      @stream = nil
      @buf.slice!(0..)
    in Integer
      # Chunks are appended as binary copies so that the buffer is always measured
      # and sliced in bytes, and chunks of differing encodings can share it.
      @buf << @stream.next.b while @buf.bytesize < max_len
      # Start/length form: an inclusive `..max_len` range would return one byte too many.
      @buf.slice!(0, max_len)
    end
  rescue StopIteration
    # The enumerator ended before `max_len` bytes were buffered: flush the remainder
    # and drop the stream so that the following `read` reports end-of-stream with nil.
    @stream = nil
    @buf.slice!(0..)
  end

  # @private
  #
  # Reads from the wrapped stream and then passes the result (including a final
  # `nil`) to the callback given at construction, if there was one.
  #
  # @param max_len [Integer, nil] maximum number of bytes to read; `nil` reads everything
  # @param out_string [String, nil] optional buffer that receives the data in place
  #
  # @return [String, nil] the data read, or nil once the stream is finished
  #
  def read(max_len = nil, out_string = nil)
    chunk =
      case @stream
      in nil
        nil
      in IO | StringIO
        @stream.read(max_len, out_string)
      in Enumerator
        data = read_enum(max_len)
        case out_string
        in String
          out_string.replace(data)
        in nil
          data
        end
      end

    # The callback is optional; `tap(&nil)` would raise LocalJumpError here.
    @blk&.call(chunk)
    chunk
  end

  # @private
  #
  # @param stream [String, IO, StringIO, Enumerable] source of bytes; a String is wrapped in a StringIO
  # @param blk [Proc] optional callback invoked with the result of every `read`
  #
  def initialize(stream, &blk)
    @stream = stream.is_a?(String) ? StringIO.new(stream) : stream
    @buf = String.new.b
    @blk = blk
  end
end
461+
462+
class << self
  # Builds a lazy enumerator whose yielder additionally answers `write`, so the
  # yielder can be handed to code that expects a writable destination.
  #
  # @param blk [Proc] receives the augmented yielder each time the enumerator runs
  #
  # @return [Enumerable]
  #
  def string_io(&blk)
    Enumerator.new do |yielder|
      # A copy of each written chunk is emitted, not the object itself, and the
      # chunk's byte count is returned to the writer.
      yielder.define_singleton_method(:write) do |chunk|
        self << chunk.clone
        chunk.bytesize
      end

      blk.call(yielder)
    end
  end
end
478+
402479
class << self
403480
# @private
404481
#
405-
# @param io [StringIO]
482+
# @param y [Enumerator::Yielder]
406483
# @param boundary [String]
407484
# @param key [Symbol, String]
408485
# @param val [Object]
409486
#
410-
private def encode_multipart_formdata(io, boundary:, key:, val:)
411-
io << "--#{boundary}\r\n"
412-
io << "Content-Disposition: form-data"
487+
private def encode_multipart_formdata(y, boundary:, key:, val:)
488+
y << "--#{boundary}\r\n"
489+
y << "Content-Disposition: form-data"
413490
unless key.nil?
414491
name = ERB::Util.url_encode(key.to_s)
415-
io << "; name=\"#{name}\""
492+
y << "; name=\"#{name}\""
416493
end
417494
if val.is_a?(IO)
418495
filename = ERB::Util.url_encode(File.basename(val.to_path))
419-
io << "; filename=\"#{filename}\""
496+
y << "; filename=\"#{filename}\""
420497
end
421-
io << "\r\n"
498+
y << "\r\n"
422499
case val
423-
in IO | StringIO
424-
io << "Content-Type: application/octet-stream\r\n\r\n"
425-
IO.copy_stream(val, io)
500+
in IO
501+
y << "Content-Type: application/octet-stream\r\n\r\n"
502+
IO.copy_stream(val, y)
503+
in StringIO
504+
y << "Content-Type: application/octet-stream\r\n\r\n"
505+
y << val.string
426506
in String
427-
io << "Content-Type: application/octet-stream\r\n\r\n"
428-
io << val.to_s
507+
y << "Content-Type: application/octet-stream\r\n\r\n"
508+
y << val.to_s
429509
in true | false | Integer | Float | Symbol
430-
io << "Content-Type: text/plain\r\n\r\n"
431-
io << val.to_s
510+
y << "Content-Type: text/plain\r\n\r\n"
511+
y << val.to_s
432512
else
433-
io << "Content-Type: application/json\r\n\r\n"
434-
io << JSON.fast_generate(val)
513+
y << "Content-Type: application/json\r\n\r\n"
514+
y << JSON.fast_generate(val)
435515
end
436-
io << "\r\n"
516+
y << "\r\n"
517+
end
518+
519+
# @private
520+
#
521+
# @param body [Object]
522+
#
523+
# @return [Array(String, Enumerable)]
524+
#
525+
private def encode_multipart_streaming(body)
  boundary = SecureRandom.urlsafe_base64(60)

  parts = string_io do |out|
    if body.is_a?(Hash)
      body.each_pair do |name, field|
        # An array made up solely of primitives becomes one part per element;
        # anything else is emitted as a single part.
        expand = field.is_a?(Array) && field.all? { |item| primitive?(item) }
        (expand ? field : [field]).each do |item|
          encode_multipart_formdata(out, boundary: boundary, key: name, val: item)
        end
      end
    else
      # A non-hash body is sent as one unnamed part.
      encode_multipart_formdata(out, boundary: boundary, key: nil, val: body)
    end

    # Closing delimiter that terminates the multipart body.
    out << "--#{boundary}--\r\n"
  end

  [boundary, parts]
end
438549

439550
# @private
@@ -449,37 +560,11 @@ def encode_content(headers, body)
449560
in ["application/json", Hash | Array]
450561
[headers, JSON.fast_generate(body)]
451562
in [%r{^multipart/form-data}, Hash | IO | StringIO]
452-
boundary = SecureRandom.urlsafe_base64(60)
453-
strio = StringIO.new.tap do |io|
454-
case body
455-
in Hash
456-
body.each do |key, val|
457-
case val
458-
in Array if val.all? { primitive?(_1) }
459-
val.each do |v|
460-
encode_multipart_formdata(io, boundary: boundary, key: key, val: v)
461-
end
462-
else
463-
encode_multipart_formdata(io, boundary: boundary, key: key, val: val)
464-
end
465-
end
466-
else
467-
encode_multipart_formdata(io, boundary: boundary, key: nil, val: body)
468-
end
469-
io << "--#{boundary}--\r\n"
470-
io.rewind
471-
end
472-
headers = {
473-
**headers,
474-
"content-type" => "#{content_type}; boundary=#{boundary}",
475-
"transfer-encoding" => "chunked"
476-
}
563+
boundary, strio = encode_multipart_streaming(body)
564+
headers = {**headers, "content-type" => "#{content_type}; boundary=#{boundary}"}
477565
[headers, strio]
478566
in [_, StringIO]
479567
[headers, body.string]
480-
in [_, IO]
481-
headers = {**headers, "transfer-encoding" => "chunked"}
482-
[headers, body]
483568
else
484569
[headers, body]
485570
end
@@ -589,8 +674,9 @@ def decode_lines(enum)
589674

590675
chain_fused(enum) do |y|
591676
enum.each do |row|
677+
offset = buffer.bytesize
592678
buffer << row
593-
while (match = re.match(buffer, cr_seen.to_i))
679+
while (match = re.match(buffer, cr_seen&.to_i || offset))
594680
case [match.captures.first, cr_seen]
595681
in ["\r", nil]
596682
cr_seen = match.end(1)
@@ -600,6 +686,7 @@ def decode_lines(enum)
600686
else
601687
y << buffer.slice!(..(match.end(1).pred))
602688
end
689+
offset = 0
603690
cr_seen = nil
604691
end
605692
end
@@ -637,7 +724,7 @@ def decode_sse(lines)
637724
in "event"
638725
current.merge!(event: value)
639726
in "data"
640-
(current[:data] ||= String.new.b) << value << "\n"
727+
(current[:data] ||= String.new.b) << (value << "\n")
641728
in "id" unless value.include?("\0")
642729
current.merge!(id: value)
643730
in "retry" if /^\d+$/ =~ value

rbi/lib/openai/pooled_net_requester.rbi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ module OpenAI
1515
def calibrate_socket_timeout(conn, deadline)
1616
end
1717

18-
sig { params(request: OpenAI::PooledNetRequester::RequestShape).returns(Net::HTTPGenericRequest) }
19-
def build_request(request)
18+
sig do
19+
params(request: OpenAI::PooledNetRequester::RequestShape, blk: T.proc.params(arg0: String).void)
20+
.returns(Net::HTTPGenericRequest)
21+
end
22+
def build_request(request, &blk)
2023
end
2124
end
2225

@@ -31,8 +34,8 @@ module OpenAI
3134
def execute(request)
3235
end
3336

34-
sig { returns(T.attached_class) }
35-
def self.new
37+
sig { params(size: Integer).returns(T.attached_class) }
38+
def self.new(size: Etc.nprocessors)
3639
end
3740
end
3841
end

rbi/lib/openai/util.rbi

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,41 @@ module OpenAI
130130
end
131131
end
132132

133+
class ReadIOAdapter
134+
sig { params(max_len: T.nilable(Integer)).returns(String) }
135+
private def read_enum(max_len)
136+
end
137+
138+
sig { params(max_len: T.nilable(Integer), out_string: T.nilable(String)).returns(T.nilable(String)) }
139+
def read(max_len = nil, out_string = nil)
140+
end
141+
142+
sig do
143+
params(
144+
stream: T.any(String, IO, StringIO, T::Enumerable[String]),
145+
blk: T.proc.params(arg0: String).void
146+
)
147+
.returns(T.attached_class)
148+
end
149+
def self.new(stream, &blk)
150+
end
151+
end
152+
153+
class << self
154+
sig { params(blk: T.proc.params(y: Enumerator::Yielder).void).returns(T::Enumerable[String]) }
155+
def string_io(&blk)
156+
end
157+
end
158+
133159
class << self
134-
sig { params(io: StringIO, boundary: String, key: T.any(Symbol, String), val: T.anything).void }
135-
private def encode_multipart_formdata(io, boundary:, key:, val:)
160+
sig do
161+
params(y: Enumerator::Yielder, boundary: String, key: T.any(Symbol, String), val: T.anything).void
162+
end
163+
private def encode_multipart_formdata(y, boundary:, key:, val:)
164+
end
165+
166+
sig { params(body: T.anything).returns([String, T::Enumerable[String]]) }
167+
private def encode_multipart_streaming(body)
136168
end
137169

138170
sig { params(headers: T::Hash[String, String], body: T.anything).returns(T.anything) }

0 commit comments

Comments
 (0)