-
Notifications
You must be signed in to change notification settings - Fork 43
Implement guided generation for SystemLanguageModel
#59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -758,21 +758,17 @@ extension LanguageModelSession { | |
|
|
||
| extension LanguageModelSession { | ||
| public struct ResponseStream<Content>: Sendable where Content: Generable, Content.PartiallyGenerated: Sendable { | ||
| private let content: Content | ||
| private let rawContent: GeneratedContent | ||
| private let fallbackSnapshot: Snapshot? | ||
| private let streaming: AsyncThrowingStream<Snapshot, any Error>? | ||
|
|
||
| init(content: Content, rawContent: GeneratedContent) { | ||
| self.content = content | ||
| self.rawContent = rawContent | ||
| self.fallbackSnapshot = Snapshot(content: content.asPartiallyGenerated(), rawContent: rawContent) | ||
| self.streaming = nil | ||
| } | ||
|
|
||
| init(stream: AsyncThrowingStream<Snapshot, any Error>) { | ||
| // Fallback values when consumers call collect() before any snapshots arrive | ||
| // These will be replaced by the last yielded snapshot during collect() | ||
| self.content = (try? Content(GeneratedContent(""))) ?? ("" as! Content) | ||
| self.rawContent = GeneratedContent("") | ||
| // When streaming, snapshots arrive from the upstream sequence, so no fallback is required. | ||
| self.fallbackSnapshot = nil | ||
| self.streaming = stream | ||
| } | ||
|
|
||
|
|
@@ -788,22 +784,14 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { | |
|
|
||
| public struct AsyncIterator: AsyncIteratorProtocol { | ||
| private var hasYielded = false | ||
| private let content: Content | ||
| private let rawContent: GeneratedContent | ||
| private let fallbackSnapshot: Snapshot? | ||
| private var streamIterator: AsyncThrowingStream<Snapshot, any Error>.AsyncIterator? | ||
| private let useStream: Bool | ||
|
|
||
| init(content: Content, rawContent: GeneratedContent, stream: AsyncThrowingStream<Snapshot, any Error>?) { | ||
| self.content = content | ||
| self.rawContent = rawContent | ||
| if let stream { | ||
| let iterator = stream.makeAsyncIterator() | ||
| self.streamIterator = iterator | ||
| self.useStream = true | ||
| } else { | ||
| self.streamIterator = nil | ||
| self.useStream = false | ||
| } | ||
| init(fallbackSnapshot: Snapshot?, stream: AsyncThrowingStream<Snapshot, any Error>?) { | ||
| self.fallbackSnapshot = fallbackSnapshot | ||
| self.streamIterator = stream?.makeAsyncIterator() | ||
| self.useStream = stream != nil | ||
| } | ||
|
|
||
| public mutating func next() async throws -> Snapshot? { | ||
|
|
@@ -818,20 +806,17 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { | |
| } | ||
| return nil | ||
| } else { | ||
| guard !hasYielded else { return nil } | ||
| guard !hasYielded, let fallbackSnapshot else { return nil } | ||
| hasYielded = true | ||
| return Snapshot( | ||
| content: content.asPartiallyGenerated(), | ||
| rawContent: rawContent | ||
| ) | ||
| return fallbackSnapshot | ||
| } | ||
| } | ||
|
|
||
| public typealias Element = Snapshot | ||
| } | ||
|
|
||
| public func makeAsyncIterator() -> AsyncIterator { | ||
| return AsyncIterator(content: content, rawContent: rawContent, stream: streaming) | ||
| return AsyncIterator(fallbackSnapshot: fallbackSnapshot, stream: streaming) | ||
| } | ||
|
|
||
| nonisolated public func collect() async throws -> sending LanguageModelSession.Response<Content> { | ||
|
|
@@ -855,9 +840,26 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { | |
| ) | ||
| } | ||
| } | ||
|
|
||
| if let fallbackSnapshot { | ||
| let finalContent: Content | ||
| if let concrete = fallbackSnapshot.content as? Content { | ||
| finalContent = concrete | ||
| } else { | ||
| finalContent = try Content(fallbackSnapshot.rawContent) | ||
| } | ||
| return LanguageModelSession.Response( | ||
| content: finalContent, | ||
| rawContent: fallbackSnapshot.rawContent, | ||
| transcriptEntries: [] | ||
| ) | ||
| } | ||
|
|
||
| // As a last resort, return an empty payload. | ||
| let empty = GeneratedContent("") | ||
| return LanguageModelSession.Response( | ||
| content: content, | ||
| rawContent: rawContent, | ||
| content: try Content(empty), | ||
| rawContent: empty, | ||
| transcriptEntries: [] | ||
| ) | ||
|
Comment on lines
+858
to
864
|
||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fatalError will crash the application when a type cannot be converted to its partially generated form. This is a critical runtime failure that could occur during normal streaming operations when partial data is invalid. Consider throwing an error instead of using fatalError, or return a fallback value that indicates the conversion failed, allowing the caller to handle the error gracefully.