Skip to content

Commit 2432eb0

Browse files
chore: more accurate generic params for stream classes (#8)
1 parent accddc2 commit 2432eb0

File tree

7 files changed

+41
-12
lines changed

7 files changed

+41
-12
lines changed

lib/openai/base_stream.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module OpenAI
1616
#
1717
# messages => Array
1818
# ```
19-
class BaseStream
19+
module BaseStream
2020
# @return [void]
2121
#
2222
def close = OpenAI::Util.close_fused!(@iterator)

lib/openai/stream.rb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ module OpenAI
1616
#
1717
# messages => Array
1818
# ```
19-
class Stream < OpenAI::BaseStream
19+
class Stream
20+
include OpenAI::BaseStream
21+
2022
# @private
2123
#
2224
# @return [Enumerable]

rbi/lib/openai/base_client.rbi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ module OpenAI
2222
body: T.nilable(T.anything),
2323
unwrap: T.nilable(Symbol),
2424
page: T.nilable(T::Class[OpenAI::BasePage[OpenAI::BaseModel]]),
25-
stream: T.nilable(T::Class[OpenAI::BaseStream[OpenAI::BaseModel]]),
25+
stream: T.nilable(T::Class[OpenAI::BaseStream[T.anything, OpenAI::BaseModel]]),
2626
model: T.nilable(OpenAI::Converter::Input),
2727
options: T.nilable(T.any(OpenAI::RequestOptions, T::Hash[Symbol, T.anything]))
2828
}
@@ -148,7 +148,7 @@ module OpenAI
148148
body: T.nilable(T.anything),
149149
unwrap: T.nilable(Symbol),
150150
page: T.nilable(T::Class[OpenAI::BasePage[OpenAI::BaseModel]]),
151-
stream: T.nilable(T::Class[OpenAI::BaseStream[OpenAI::BaseModel]]),
151+
stream: T.nilable(T::Class[OpenAI::BaseStream[T.anything, OpenAI::BaseModel]]),
152152
model: T.nilable(OpenAI::Converter::Input),
153153
options: T.nilable(T.any(OpenAI::RequestOptions, T::Hash[Symbol, T.anything]))
154154
)

rbi/lib/openai/base_stream.rbi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# typed: strong
22

33
module OpenAI
4-
class BaseStream
4+
module BaseStream
5+
Message = type_member(:in)
56
Elem = type_member(:out)
67

78
sig { void }
@@ -28,11 +29,11 @@ module OpenAI
2829
url: URI::Generic,
2930
status: Integer,
3031
response: Net::HTTPResponse,
31-
messages: T::Enumerable[OpenAI::Util::SSEMessage]
32+
messages: T::Enumerable[Message]
3233
)
33-
.returns(T.attached_class)
34+
.void
3435
end
35-
def self.new(model:, url:, status:, response:, messages:)
36+
def initialize(model:, url:, status:, response:, messages:)
3637
end
3738
end
3839
end

rbi/lib/openai/stream.rbi

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
# typed: strong
22

33
module OpenAI
4-
class Stream < OpenAI::BaseStream
4+
class Stream
5+
include OpenAI::BaseStream
6+
7+
Message = type_member(:in) { {fixed: OpenAI::Util::SSEMessage} }
58
Elem = type_member(:out)
69

710
sig { override.returns(T::Enumerable[Elem]) }
811
private def iterator
912
end
13+
14+
sig do
15+
params(
16+
model: T.any(T::Class[T.anything], OpenAI::Converter),
17+
url: URI::Generic,
18+
status: Integer,
19+
response: Net::HTTPResponse,
20+
messages: T::Enumerable[OpenAI::Util::SSEMessage]
21+
)
22+
.returns(T.attached_class)
23+
end
24+
def self.new(model:, url:, status:, response:, messages:)
25+
end
1026
end
1127
end

sig/openai/base_stream.rbs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module OpenAI
2-
class BaseStream[Elem]
2+
module BaseStream[Message, Elem]
33
def close: -> void
44

55
private def iterator: -> Enumerable[Elem]
@@ -15,7 +15,7 @@ module OpenAI
1515
url: URI::Generic,
1616
status: Integer,
1717
response: top,
18-
messages: Enumerable[OpenAI::Util::sse_message]
18+
messages: Enumerable[Message]
1919
) -> void
2020
end
2121
end

sig/openai/stream.rbs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
module OpenAI
2-
class Stream[Elem] < OpenAI::BaseStream[Elem]
2+
class Stream[Elem]
3+
include OpenAI::BaseStream[OpenAI::Util::sse_message, Elem]
4+
35
private def iterator: -> Enumerable[Elem]
6+
7+
def initialize: (
8+
model: Class | OpenAI::Converter,
9+
url: URI::Generic,
10+
status: Integer,
11+
response: top,
12+
messages: Enumerable[OpenAI::Util::sse_message]
13+
) -> void
414
end
515
end

0 commit comments

Comments
 (0)