Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions FirebaseAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ public final class GenerativeModel: Sendable {
@available(macOS 12.0, *)
public func generateContentStream(_ content: [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try generateContentStream(content, generationConfig: generationConfig)
}

public func generateContentStream(_ content: [ModelContent],
generationConfig: GenerationConfig?) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
try content.throwIfError()
let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
Expand Down Expand Up @@ -511,6 +517,109 @@ public final class GenerativeModel: Sendable {
return Response(content: contentValue, rawContent: rawContent)
}

private func _generateObjectStream<Content>(parts: [any PartsRepresentable],
jsonSchemaProvider: @escaping @Sendable () throws
-> JSONObject,
contentProvider: @escaping @Sendable (ModelOutput)
throws -> Content)
-> AsyncThrowingStream<Response<Content>, Error> {
let content = parts.flatMap { $0.partsValue }
let modelContent = ModelContent(parts: content)

return AsyncThrowingStream { continuation in
Task {
do {
let jsonSchema = try jsonSchemaProvider()
let config = generationConfig(from: generationConfig, with: jsonSchema)

let responseStream = try generateContentStream(
[modelContent],
generationConfig: config
)

var fullText = ""
// Accumulate the response text.
for try await chunk in responseStream {
if let text = chunk.text {
fullText += text
}

// Attempt to parse and yield partial results.
let cleanText = GenerativeModel.cleanedJSON(from: fullText)
let parser = PartialJSONParser(input: cleanText)
if let jsonValue = parser.parse() {
do {
let rawContent = ModelOutput(jsonValue: jsonValue)
let contentValue = try contentProvider(rawContent)
continuation.yield(Response(content: contentValue, rawContent: rawContent))
} catch {
// Ignore conversion errors for partial content.
}
}
Comment on lines 541 to 565
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation accumulates the entire response text in fullText and re-parses it from the beginning with each new chunk. For very large streaming responses, this could become inefficient in terms of both memory (for fullText) and CPU (for re-parsing).

A potential future optimization could be to design the PartialJSONParser to work with a stream of characters or chunks directly, avoiding the need to buffer the entire response and re-parse known-good sections.

While the current approach is simple and likely sufficient for typical response sizes, it's worth keeping this in mind for future performance improvements.

}

// TODO: Remove when extraneous '```json' prefix from JSON payload no longer returned.
let json = GenerativeModel.cleanedJSON(from: fullText)
let rawContent = try GenerativeModel.parseModelOutput(from: json)
let contentValue = try contentProvider(rawContent)

// Yield the final, strictly parsed result.
// This ensures the consumer receives the complete object validated against the schema,
// even if the partial parser yielded a similar result in the loop.
continuation.yield(Response(content: contentValue, rawContent: rawContent))
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
}
}

#if canImport(FoundationModels)
/// Generates a stream of `Content` objects from the model.
///
/// - Parameters:
/// - type: A type to produce as the response.
/// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
/// conforming types).
/// - Returns: A stream containing the generated `Content` object.
@available(iOS 26.0, macOS 26.0, *)
@available(tvOS, unavailable)
@available(watchOS, unavailable)
public final func generateObjectStream<Content>(_ type: Content.Type = Content.self,
parts: any PartsRepresentable...)
-> AsyncThrowingStream<Response<Content>, Error>
where Content: FoundationModels.Generable {
return _generateObjectStream(
parts: parts,
jsonSchemaProvider: { try type.generationSchema.asGeminiJSONSchema() },
contentProvider: { rawContent in
let generatedContent = rawContent.generatedContent
return try Content(generatedContent)
}
)
}
#endif // canImport(FoundationModels)

/// Generates a stream of `Content` objects from the model.
///
/// - Parameters:
/// - type: A type to produce as the response.
/// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
/// conforming types).
/// - Returns: A stream containing the generated `Content` object.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final func generateObjectStream<Content>(_ type: Content.Type = Content.self,
parts: any PartsRepresentable...)
-> AsyncThrowingStream<Response<Content>, Error>
where Content: FirebaseGenerable {
return _generateObjectStream(
parts: parts,
jsonSchemaProvider: { try type.jsonSchema.asGeminiJSONSchema() },
contentProvider: { try Content($0) }
)
}

/// A structure that stores the output of a response call.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct Response<Content> {
Expand Down
271 changes: 271 additions & 0 deletions FirebaseAI/Sources/PartialJSONParser.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A parser that attempts to parse partial JSON strings into `JSONValue`.
///
/// This parser is tolerant of incomplete JSON structures (e.g., unclosed objects, arrays, strings)
/// and attempts to return the valid structure parsed so far. This is useful for streaming
/// applications where JSON is received in chunks.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class PartialJSONParser {
private let input: [Character]
private var index: Int
private let length: Int

init(input: String) {
self.input = Array(input)
index = 0
length = self.input.count
}

/// Parses the input string into a `JSONValue`.
/// Returns `nil` if the input is empty or cannot be parsed as a value.
func parse() -> JSONValue? {
skipWhitespace()
if index >= length {
return nil
}
return parseValue()
}

private func parseValue() -> JSONValue? {
skipWhitespace()
if index >= length { return nil }

let char = input[index]
switch char {
case "{":
return parseObject()
case "[":
return parseArray()
case "\"":
return parseString()
case "t":
return parseTrue()
case "f":
return parseFalse()
case "n":
return parseNull()
case "-", "0" ... "9":
return parseNumber()
default:
// If we encounter an unexpected character, we might be in an invalid state.
return nil
}
}

private func parseObject() -> JSONValue {
// Consume '{'
index += 1

var object: JSONObject = [:]

while index < length {
skipWhitespace()
if index >= length {
// EOF inside object, return what we have
return .object(object)
}

let char = input[index]
if char == "}" {
index += 1
return .object(object)
}

// Expect key
if char == "\"" {
// Parse key
// parseString returns .string(val) or .null (actually never .null if called on quote)
if case let .string(key) = parseString() {
skipWhitespace()

// Expect ':'
if index < length, input[index] == ":" {
index += 1 // consume ':'
if let value = parseValue() {
object[key] = value
}
// If value is nil (EOF), we ignore this key
}
}
} else {
// Unexpected character in object, maybe a comma?
if char == "," {
index += 1
continue
}
// Invalid or unexpected, abort and return what we have
return .object(object)
}
}

return .object(object)
}

private func parseArray() -> JSONValue {
// Consume '['
index += 1

var array: [JSONValue] = []

while index < length {
skipWhitespace()
if index >= length {
return .array(array)
}

let char = input[index]
if char == "]" {
index += 1
return .array(array)
}

if char == "," {
index += 1
continue
}

if let value = parseValue() {
array.append(value)
} else {
// EOF or invalid
return .array(array)
}
}

return .array(array)
}

private func parseString() -> JSONValue {
// Consume '"'
index += 1

var string = ""
var escaped = false

while index < length {
let char = input[index]
index += 1

if escaped {
// Handle basic escapes
switch char {
case "\"": string.append("\"")
case "\\": string.append("\\")
case "/": string.append("/")
case "b": string.append("\u{08}")
case "f": string.append("\u{0C}")
case "n": string.append("\n")
case "r": string.append("\r")
case "t": string.append("\t")
case "u":
// Unicode escape
// Need 4 chars
if index + 4 <= length {
let hex = String(input[index ..< index + 4])
if let scalar = Int(hex, radix: 16), let uScalar = UnicodeScalar(scalar) {
string.append(Character(uScalar))
index += 4
} else {
// Invalid unicode, just append u...
string.append("\\u")
}
} else {
// Incomplete unicode
string.append("\\u")
// And we are probably near EOF
}
default:
string.append(char)
}
escaped = false
} else {
if char == "\"" {
return .string(string)
} else if char == "\\" {
escaped = true
} else {
string.append(char)
}
}
}

// Hit EOF without closing quote
// Return what we have
return .string(string)
}

private func parseNumber() -> JSONValue? {
let start = index
// Consume until non-numeric
// allowed: 0-9, -, +, ., e, E

while index < length {
let char = input[index]
if "0123456789-+.eE".contains(char) {
index += 1
} else {
break
}
}

let numberString = String(input[start ..< index])
if let double = Double(numberString) {
return .number(double)
}
// If partial number (e.g. "-"), return nil
return nil
}

private func parseTrue() -> JSONValue? {
// Expect "true"
if match("true") { return .bool(true) }
return nil
}

private func parseFalse() -> JSONValue? {
if match("false") { return .bool(false) }
return nil
}

private func parseNull() -> JSONValue? {
if match("null") { return .null }
return nil
}

private func match(_ string: String) -> Bool {
let chars = Array(string)
if index + chars.count <= length {
if Array(input[index ..< index + chars.count]) == chars {
index += chars.count
return true
}
}
return false
}

private func skipWhitespace() {
while index < length {
let char = input[index]
if char.isWhitespace {
index += 1
} else {
break
}
}
}
}
Loading
Loading