-
Notifications
You must be signed in to change notification settings - Fork 1.7k
[AI] Structured Output streaming #15652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pb-ah-generable
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -221,6 +221,12 @@ public final class GenerativeModel: Sendable { | |
| @available(macOS 12.0, *) | ||
| public func generateContentStream(_ content: [ModelContent]) throws | ||
| -> AsyncThrowingStream<GenerateContentResponse, Error> { | ||
| return try generateContentStream(content, generationConfig: generationConfig) | ||
| } | ||
|
|
||
| public func generateContentStream(_ content: [ModelContent], | ||
| generationConfig: GenerationConfig?) throws | ||
| -> AsyncThrowingStream<GenerateContentResponse, Error> { | ||
| try content.throwIfError() | ||
| let generateContentRequest = GenerateContentRequest( | ||
| model: modelResourceName, | ||
|
|
@@ -511,6 +517,109 @@ public final class GenerativeModel: Sendable { | |
| return Response(content: contentValue, rawContent: rawContent) | ||
| } | ||
|
|
||
| private func _generateObjectStream<Content>(parts: [any PartsRepresentable], | ||
| jsonSchemaProvider: @escaping @Sendable () throws | ||
| -> JSONObject, | ||
| contentProvider: @escaping @Sendable (ModelOutput) | ||
| throws -> Content) | ||
| -> AsyncThrowingStream<Response<Content>, Error> { | ||
| let content = parts.flatMap { $0.partsValue } | ||
| let modelContent = ModelContent(parts: content) | ||
|
|
||
| return AsyncThrowingStream { continuation in | ||
| Task { | ||
| do { | ||
| let jsonSchema = try jsonSchemaProvider() | ||
| let config = generationConfig(from: generationConfig, with: jsonSchema) | ||
|
|
||
| let responseStream = try generateContentStream( | ||
| [modelContent], | ||
| generationConfig: config | ||
| ) | ||
|
|
||
| var fullText = "" | ||
| // Accumulate the response text. | ||
| for try await chunk in responseStream { | ||
| if let text = chunk.text { | ||
| fullText += text | ||
| } | ||
|
|
||
| // Attempt to parse and yield partial results. | ||
| let cleanText = GenerativeModel.cleanedJSON(from: fullText) | ||
| let parser = PartialJSONParser(input: cleanText) | ||
| if let jsonValue = parser.parse() { | ||
| do { | ||
| let rawContent = ModelOutput(jsonValue: jsonValue) | ||
| let contentValue = try contentProvider(rawContent) | ||
| continuation.yield(Response(content: contentValue, rawContent: rawContent)) | ||
| } catch { | ||
| // Ignore conversion errors for partial content. | ||
| } | ||
paulb777 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
Comment on lines
541
to
565
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation accumulates the entire response text in A potential future optimization could be to design the While the current approach is simple and likely sufficient for typical response sizes, it's worth keeping this in mind for future performance improvements. |
||
| } | ||
|
|
||
| // TODO: Remove when extraneous '```json' prefix from JSON payload no longer returned. | ||
| let json = GenerativeModel.cleanedJSON(from: fullText) | ||
| let rawContent = try GenerativeModel.parseModelOutput(from: json) | ||
| let contentValue = try contentProvider(rawContent) | ||
|
|
||
| // Yield the final, strictly parsed result. | ||
| // This ensures the consumer receives the complete object validated against the schema, | ||
| // even if the partial parser yielded a similar result in the loop. | ||
| continuation.yield(Response(content: contentValue, rawContent: rawContent)) | ||
| continuation.finish() | ||
| } catch { | ||
| continuation.finish(throwing: error) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #if canImport(FoundationModels) | ||
| /// Generates a stream of `Content` objects from the model. | ||
| /// | ||
| /// - Parameters: | ||
| /// - type: A type to produce as the response. | ||
| /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for | ||
| /// conforming types). | ||
| /// - Returns: A stream containing the generated `Content` object. | ||
| @available(iOS 26.0, macOS 26.0, *) | ||
| @available(tvOS, unavailable) | ||
| @available(watchOS, unavailable) | ||
| public final func generateObjectStream<Content>(_ type: Content.Type = Content.self, | ||
| parts: any PartsRepresentable...) | ||
| -> AsyncThrowingStream<Response<Content>, Error> | ||
| where Content: FoundationModels.Generable { | ||
| return _generateObjectStream( | ||
| parts: parts, | ||
| jsonSchemaProvider: { try type.generationSchema.asGeminiJSONSchema() }, | ||
| contentProvider: { rawContent in | ||
| let generatedContent = rawContent.generatedContent | ||
| return try Content(generatedContent) | ||
| } | ||
| ) | ||
| } | ||
| #endif // canImport(FoundationModels) | ||
|
|
||
| /// Generates a stream of `Content` objects from the model. | ||
| /// | ||
| /// - Parameters: | ||
| /// - type: A type to produce as the response. | ||
| /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for | ||
| /// conforming types). | ||
| /// - Returns: A stream containing the generated `Content` object. | ||
| @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
| public final func generateObjectStream<Content>(_ type: Content.Type = Content.self, | ||
| parts: any PartsRepresentable...) | ||
| -> AsyncThrowingStream<Response<Content>, Error> | ||
| where Content: FirebaseGenerable { | ||
| return _generateObjectStream( | ||
| parts: parts, | ||
| jsonSchemaProvider: { try type.jsonSchema.asGeminiJSONSchema() }, | ||
| contentProvider: { try Content($0) } | ||
| ) | ||
| } | ||
|
|
||
| /// A structure that stores the output of a response call. | ||
| @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
| public struct Response<Content> { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,271 @@ | ||
| // Copyright 2025 Google LLC | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| import Foundation | ||
|
|
||
| /// A parser that attempts to parse partial JSON strings into `JSONValue`. | ||
| /// | ||
| /// This parser is tolerant of incomplete JSON structures (e.g., unclosed objects, arrays, strings) | ||
| /// and attempts to return the valid structure parsed so far. This is useful for streaming | ||
| /// applications where JSON is received in chunks. | ||
| @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
| final class PartialJSONParser { | ||
| private let input: [Character] | ||
| private var index: Int | ||
| private let length: Int | ||
|
|
||
| init(input: String) { | ||
| self.input = Array(input) | ||
| index = 0 | ||
| length = self.input.count | ||
| } | ||
|
|
||
| /// Parses the input string into a `JSONValue`. | ||
| /// Returns `nil` if the input is empty or cannot be parsed as a value. | ||
| func parse() -> JSONValue? { | ||
| skipWhitespace() | ||
| if index >= length { | ||
| return nil | ||
| } | ||
| return parseValue() | ||
| } | ||
|
|
||
| private func parseValue() -> JSONValue? { | ||
| skipWhitespace() | ||
| if index >= length { return nil } | ||
|
|
||
| let char = input[index] | ||
| switch char { | ||
| case "{": | ||
| return parseObject() | ||
| case "[": | ||
| return parseArray() | ||
| case "\"": | ||
| return parseString() | ||
| case "t": | ||
| return parseTrue() | ||
| case "f": | ||
| return parseFalse() | ||
| case "n": | ||
| return parseNull() | ||
| case "-", "0" ... "9": | ||
| return parseNumber() | ||
| default: | ||
| // If we encounter an unexpected character, we might be in an invalid state. | ||
| return nil | ||
| } | ||
| } | ||
|
|
||
| private func parseObject() -> JSONValue { | ||
| // Consume '{' | ||
| index += 1 | ||
|
|
||
| var object: JSONObject = [:] | ||
|
|
||
| while index < length { | ||
| skipWhitespace() | ||
| if index >= length { | ||
| // EOF inside object, return what we have | ||
| return .object(object) | ||
| } | ||
|
|
||
| let char = input[index] | ||
| if char == "}" { | ||
| index += 1 | ||
| return .object(object) | ||
| } | ||
|
|
||
| // Expect key | ||
| if char == "\"" { | ||
| // Parse key | ||
| // parseString returns .string(val) or .null (actually never .null if called on quote) | ||
| if case let .string(key) = parseString() { | ||
| skipWhitespace() | ||
|
|
||
| // Expect ':' | ||
| if index < length, input[index] == ":" { | ||
| index += 1 // consume ':' | ||
| if let value = parseValue() { | ||
| object[key] = value | ||
| } | ||
| // If value is nil (EOF), we ignore this key | ||
| } | ||
| } | ||
| } else { | ||
| // Unexpected character in object, maybe a comma? | ||
| if char == "," { | ||
| index += 1 | ||
| continue | ||
| } | ||
| // Invalid or unexpected, abort and return what we have | ||
| return .object(object) | ||
| } | ||
| } | ||
|
|
||
| return .object(object) | ||
| } | ||
|
|
||
| private func parseArray() -> JSONValue { | ||
| // Consume '[' | ||
| index += 1 | ||
|
|
||
| var array: [JSONValue] = [] | ||
|
|
||
| while index < length { | ||
| skipWhitespace() | ||
| if index >= length { | ||
| return .array(array) | ||
| } | ||
|
|
||
| let char = input[index] | ||
| if char == "]" { | ||
| index += 1 | ||
| return .array(array) | ||
| } | ||
|
|
||
| if char == "," { | ||
| index += 1 | ||
| continue | ||
| } | ||
|
|
||
| if let value = parseValue() { | ||
| array.append(value) | ||
| } else { | ||
| // EOF or invalid | ||
| return .array(array) | ||
| } | ||
| } | ||
|
|
||
| return .array(array) | ||
| } | ||
|
|
||
| private func parseString() -> JSONValue { | ||
| // Consume '"' | ||
| index += 1 | ||
|
|
||
| var string = "" | ||
| var escaped = false | ||
|
|
||
| while index < length { | ||
| let char = input[index] | ||
| index += 1 | ||
|
|
||
| if escaped { | ||
| // Handle basic escapes | ||
| switch char { | ||
| case "\"": string.append("\"") | ||
| case "\\": string.append("\\") | ||
| case "/": string.append("/") | ||
| case "b": string.append("\u{08}") | ||
| case "f": string.append("\u{0C}") | ||
| case "n": string.append("\n") | ||
| case "r": string.append("\r") | ||
| case "t": string.append("\t") | ||
| case "u": | ||
| // Unicode escape | ||
| // Need 4 chars | ||
| if index + 4 <= length { | ||
| let hex = String(input[index ..< index + 4]) | ||
| if let scalar = Int(hex, radix: 16), let uScalar = UnicodeScalar(scalar) { | ||
| string.append(Character(uScalar)) | ||
| index += 4 | ||
| } else { | ||
| // Invalid unicode, just append u... | ||
| string.append("\\u") | ||
| } | ||
| } else { | ||
| // Incomplete unicode | ||
| string.append("\\u") | ||
| // And we are probably near EOF | ||
| } | ||
| default: | ||
| string.append(char) | ||
| } | ||
| escaped = false | ||
| } else { | ||
| if char == "\"" { | ||
| return .string(string) | ||
| } else if char == "\\" { | ||
| escaped = true | ||
| } else { | ||
| string.append(char) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Hit EOF without closing quote | ||
| // Return what we have | ||
| return .string(string) | ||
| } | ||
|
|
||
| private func parseNumber() -> JSONValue? { | ||
| let start = index | ||
| // Consume until non-numeric | ||
| // allowed: 0-9, -, +, ., e, E | ||
|
|
||
| while index < length { | ||
| let char = input[index] | ||
| if "0123456789-+.eE".contains(char) { | ||
| index += 1 | ||
| } else { | ||
| break | ||
| } | ||
| } | ||
|
|
||
| let numberString = String(input[start ..< index]) | ||
| if let double = Double(numberString) { | ||
| return .number(double) | ||
| } | ||
| // If partial number (e.g. "-"), return nil | ||
| return nil | ||
| } | ||
paulb777 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| private func parseTrue() -> JSONValue? { | ||
| // Expect "true" | ||
| if match("true") { return .bool(true) } | ||
| return nil | ||
| } | ||
|
|
||
| private func parseFalse() -> JSONValue? { | ||
| if match("false") { return .bool(false) } | ||
| return nil | ||
| } | ||
|
|
||
| private func parseNull() -> JSONValue? { | ||
| if match("null") { return .null } | ||
| return nil | ||
| } | ||
|
|
||
| private func match(_ string: String) -> Bool { | ||
| let chars = Array(string) | ||
| if index + chars.count <= length { | ||
| if Array(input[index ..< index + chars.count]) == chars { | ||
| index += chars.count | ||
| return true | ||
| } | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| private func skipWhitespace() { | ||
| while index < length { | ||
| let char = input[index] | ||
| if char.isWhitespace { | ||
| index += 1 | ||
| } else { | ||
| break | ||
| } | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.