Skip to content

Commit a883583

Browse files
committed
[AI] Add GenerativeModelSession class
1 parent 6e0fdae commit a883583

File tree

5 files changed

+360
-51
lines changed

5 files changed

+360
-51
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if canImport(FoundationModels)
import Foundation
import FoundationModels

@available(iOS 26.0, macOS 26.0, *)
@available(tvOS, unavailable)
@available(watchOS, unavailable)
extension GenerationSchema {
  /// Returns a Gemini-compatible JSON Schema representation of this `GenerationSchema`.
  ///
  /// The schema is round-tripped through its `Codable` representation, then the
  /// FoundationModels-specific `x-order` key (when present) is renamed to Gemini's
  /// `propertyOrdering` key.
  ///
  /// - Throws: Any error raised while encoding or decoding the schema.
  func toGeminiJSONSchema() throws -> JSONObject {
    let encodedSchema = try JSONEncoder().encode(self)
    var schemaObject = try JSONDecoder().decode(JSONObject.self, from: encodedSchema)

    // Gemini expects `propertyOrdering` where FoundationModels emits `x-order`.
    // NOTE(review): only the top-level key is renamed; confirm whether nested object
    // schemas can also carry `x-order` and would need the same treatment.
    if let ordering = schemaObject.removeValue(forKey: "x-order") {
      schemaObject["propertyOrdering"] = ordering
    }

    return schemaObject
  }
}
#endif // canImport(FoundationModels)

FirebaseAI/Sources/GenerationConfig.swift

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,45 +19,45 @@ import Foundation
1919
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
2020
public struct GenerationConfig: Sendable {
2121
/// Controls the degree of randomness in token selection.
22-
let temperature: Float?
22+
var temperature: Float?
2323

2424
/// Controls diversity of generated text.
25-
let topP: Float?
25+
var topP: Float?
2626

2727
/// Limits the number of highest probability words considered.
28-
let topK: Int?
28+
var topK: Int?
2929

3030
/// The number of response variations to return.
31-
let candidateCount: Int?
31+
var candidateCount: Int?
3232

3333
/// Maximum number of tokens that can be generated in the response.
34-
let maxOutputTokens: Int?
34+
var maxOutputTokens: Int?
3535

3636
/// Controls the likelihood of repeating the same words or phrases already generated in the text.
37-
let presencePenalty: Float?
37+
var presencePenalty: Float?
3838

3939
/// Controls the likelihood of repeating words, with the penalty increasing for each repetition.
40-
let frequencyPenalty: Float?
40+
var frequencyPenalty: Float?
4141

4242
/// A set of up to 5 `String`s that will stop output generation.
43-
let stopSequences: [String]?
43+
var stopSequences: [String]?
4444

4545
/// Output response MIME type of the generated candidate text.
46-
let responseMIMEType: String?
46+
var responseMIMEType: String?
4747

4848
/// Output schema of the generated candidate text.
49-
let responseSchema: Schema?
49+
var responseSchema: Schema?
5050

5151
/// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format.
5252
///
5353
/// If set, `responseSchema` must be omitted and `responseMIMEType` is required.
54-
let responseJSONSchema: JSONObject?
54+
var responseJSONSchema: JSONObject?
5555

5656
/// Supported modalities of the response.
57-
let responseModalities: [ResponseModality]?
57+
var responseModalities: [ResponseModality]?
5858

5959
/// Configuration for controlling the "thinking" behavior of compatible Gemini models.
60-
let thinkingConfig: ThinkingConfig?
60+
var thinkingConfig: ThinkingConfig?
6161

6262
/// Creates a new `GenerationConfig` value.
6363
///
@@ -203,6 +203,54 @@ public struct GenerationConfig: Sendable {
203203
self.responseModalities = responseModalities
204204
self.thinkingConfig = thinkingConfig
205205
}
206+
207+
/// Merges two configurations, giving precedence to values found in the `overrides` parameter.
208+
///
209+
/// - Parameters:
210+
/// - base: The foundational configuration (e.g., model-level defaults).
211+
/// - overrides: The configuration containing values that should supersede the base (e.g.,
212+
/// request-level specific settings).
213+
/// - Returns: A merged `GenerationConfig` prioritizing `overrides`, or `nil` if both inputs are
214+
/// `nil`.
215+
/// Merges two configurations, giving precedence to values found in the `overrides` parameter.
///
/// - Parameters:
///   - base: The foundational configuration (e.g., model-level defaults).
///   - overrides: The configuration containing values that should supersede the base (e.g.,
///     request-level specific settings).
/// - Returns: A merged `GenerationConfig` prioritizing `overrides`, or `nil` if both inputs are
///   `nil`.
static func merge(_ base: GenerationConfig?,
                  with overrides: GenerationConfig?) -> GenerationConfig? {
  // Without a base there is nothing to merge into; the overrides (possibly nil) win outright.
  guard let baseConfig = base else {
    return overrides
  }

  // Without overrides, the base is returned untouched.
  guard let overrideConfig = overrides else {
    return baseConfig
  }

  // Copy the base, then replace every field for which the overrides supply a non-nil value.
  var merged = baseConfig
  merged.temperature = overrideConfig.temperature ?? merged.temperature
  merged.topP = overrideConfig.topP ?? merged.topP
  merged.topK = overrideConfig.topK ?? merged.topK
  merged.candidateCount = overrideConfig.candidateCount ?? merged.candidateCount
  merged.maxOutputTokens = overrideConfig.maxOutputTokens ?? merged.maxOutputTokens
  merged.presencePenalty = overrideConfig.presencePenalty ?? merged.presencePenalty
  merged.frequencyPenalty = overrideConfig.frequencyPenalty ?? merged.frequencyPenalty
  merged.stopSequences = overrideConfig.stopSequences ?? merged.stopSequences
  merged.responseMIMEType = overrideConfig.responseMIMEType ?? merged.responseMIMEType
  merged.responseModalities = overrideConfig.responseModalities ?? merged.responseModalities
  merged.thinkingConfig = overrideConfig.thinkingConfig ?? merged.thinkingConfig

  // `responseSchema` and `responseJSONSchema` are mutually exclusive: when the overrides set
  // either one, it wins and the other is cleared. `responseJSONSchema` takes precedence if the
  // overrides set both.
  if let jsonSchema = overrideConfig.responseJSONSchema {
    merged.responseJSONSchema = jsonSchema
    merged.responseSchema = nil
  } else if let schema = overrideConfig.responseSchema {
    merged.responseSchema = schema
    merged.responseJSONSchema = nil
  }

  return merged
}
206254
}
207255

208256
// MARK: - Codable Conformances

FirebaseAI/Sources/GenerativeModel.swift

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -144,44 +144,7 @@ public final class GenerativeModel: Sendable {
144144
/// - Throws: A ``GenerateContentError`` if the request failed.
145145
public func generateContent(_ content: [ModelContent]) async throws
146146
-> GenerateContentResponse {
147-
try content.throwIfError()
148-
let response: GenerateContentResponse
149-
let generateContentRequest = GenerateContentRequest(
150-
model: modelResourceName,
151-
contents: content,
152-
generationConfig: generationConfig,
153-
safetySettings: safetySettings,
154-
tools: tools,
155-
toolConfig: toolConfig,
156-
systemInstruction: systemInstruction,
157-
apiConfig: apiConfig,
158-
apiMethod: .generateContent,
159-
options: requestOptions
160-
)
161-
do {
162-
response = try await generativeAIService.loadRequest(request: generateContentRequest)
163-
} catch {
164-
throw GenerativeModel.generateContentError(from: error)
165-
}
166-
167-
// Check the prompt feedback to see if the prompt was blocked.
168-
if response.promptFeedback?.blockReason != nil {
169-
throw GenerateContentError.promptBlocked(response: response)
170-
}
171-
172-
// Check to see if an error should be thrown for stop reason.
173-
if let reason = response.candidates.first?.finishReason, reason != .stop {
174-
throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
175-
}
176-
177-
// If all candidates are empty (contain no information that a developer could act on) then throw
178-
if response.candidates.allSatisfy({ $0.isEmpty }) {
179-
throw GenerateContentError.internalError(underlying: InvalidCandidateError.emptyContent(
180-
underlyingError: Candidate.EmptyContentError()
181-
))
182-
}
183-
184-
return response
147+
return try await generateContent(content, generationConfig: generationConfig)
185148
}
186149

187150
/// Generates content from String and/or image inputs, given to the model as a prompt, that are
@@ -357,6 +320,51 @@ public final class GenerativeModel: Sendable {
357320
return try await generativeAIService.loadRequest(request: countTokensRequest)
358321
}
359322

323+
// MARK: - Internal
324+
325+
/// Generates content from the given model content, using the supplied generation config in
/// place of the model-level default.
///
/// NOTE(review): declared `public` but grouped under an "Internal" MARK in the original;
/// confirm whether this entry point is intended to be public API.
///
/// - Parameters:
///   - content: The input given to the model as a prompt.
///   - generationConfig: The per-request generation configuration, or `nil` for none.
/// - Returns: The content generated by the model.
/// - Throws: A ``GenerateContentError`` if the request failed, the prompt was blocked,
///   generation stopped early, or every candidate was empty.
public func generateContent(_ content: [ModelContent],
                            generationConfig: GenerationConfig?) async throws
  -> GenerateContentResponse {
  try content.throwIfError()

  let request = GenerateContentRequest(
    model: modelResourceName,
    contents: content,
    generationConfig: generationConfig,
    safetySettings: safetySettings,
    tools: tools,
    toolConfig: toolConfig,
    systemInstruction: systemInstruction,
    apiConfig: apiConfig,
    apiMethod: .generateContent,
    options: requestOptions
  )

  let response: GenerateContentResponse
  do {
    response = try await generativeAIService.loadRequest(request: request)
  } catch {
    throw GenerativeModel.generateContentError(from: error)
  }

  // A blocked prompt produces no usable candidates; surface it as an error.
  if response.promptFeedback?.blockReason != nil {
    throw GenerateContentError.promptBlocked(response: response)
  }

  // Any finish reason other than a natural stop is surfaced as an error.
  if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
    throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
  }

  // If every candidate is empty, there is nothing actionable for the caller.
  guard !response.candidates.allSatisfy({ $0.isEmpty }) else {
    throw GenerateContentError.internalError(underlying: InvalidCandidateError.emptyContent(
      underlyingError: Candidate.EmptyContentError()
    ))
  }

  return response
}
367+
360368
/// Returns a `GenerateContentError` (for public consumption) from an internal error.
361369
///
362370
/// If `error` is already a `GenerateContentError` the error is returned unchanged.
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if canImport(FoundationModels)
import Foundation
import FoundationModels

/// A session wrapper that lets a ``GenerativeModel`` be driven with FoundationModels types
/// (`GenerationSchema`, `Generable`, `GeneratedContent`).
@available(iOS 26.0, macOS 26.0, *)
@available(tvOS, unavailable)
@available(watchOS, unavailable)
public final class GenerativeModelSession: Sendable {
  /// The underlying model that all requests are routed through.
  let generativeModel: GenerativeModel

  /// Creates a session backed by the given model.
  public init(model: GenerativeModel) {
    generativeModel = model
  }

  /// Generates a plain-text response to `prompt`.
  ///
  /// - Parameters:
  ///   - prompt: The parts composing the prompt sent to the model.
  ///   - options: Per-request generation options; non-nil fields override the model defaults.
  /// - Returns: The generated text wrapped in a ``Response``.
  /// - Throws: ``GenerationError/decodingFailure(_:)`` if the response contains no text, or any
  ///   error thrown while generating content.
  @discardableResult
  public final nonisolated(nonsending)
  func respond(to prompt: PartsRepresentable..., options: GenerationConfig? = nil) async throws
    -> GenerativeModelSession.Response<String> {
    let config = makeTextConfig(options: options)
    let (text, response) = try await generateText([ModelContent(parts: prompt)], config: config)

    return GenerativeModelSession.Response(
      content: text, rawContent: GeneratedContent(kind: .string(text)), rawResponse: response
    )
  }

  /// Generates a JSON response to `prompt` conforming to `schema`.
  ///
  /// - Parameters:
  ///   - prompt: The parts composing the prompt sent to the model.
  ///   - schema: The schema describing the expected shape of the generated JSON.
  ///   - includeSchemaInPrompt: Whether to send the schema with the request.
  ///   - options: Per-request generation options; non-nil fields override the model defaults.
  /// - Returns: The parsed `GeneratedContent` wrapped in a ``Response``.
  /// - Throws: ``GenerationError/decodingFailure(_:)`` if the response contains no text, or any
  ///   error thrown while converting the schema, generating content, or parsing the JSON.
  @discardableResult
  public final nonisolated(nonsending)
  func respond(to prompt: PartsRepresentable..., schema: GenerationSchema,
               includeSchemaInPrompt: Bool = true, options: GenerationConfig? = nil) async throws
    -> GenerativeModelSession.Response<GeneratedContent> {
    var config = makeTextConfig(options: options)
    config.responseMIMEType = "application/json"
    // NOTE(review): when `includeSchemaInPrompt` is false no schema is sent at all, so the
    // response is not constrained to `schema` — confirm this matches the intended semantics.
    config.responseJSONSchema = includeSchemaInPrompt ? try schema.toGeminiJSONSchema() : nil
    config.responseSchema = nil // `responseSchema` must not be set with `responseJSONSchema`

    let (text, response) = try await generateText([ModelContent(parts: prompt)], config: config)
    let generatedContent = try GeneratedContent(json: text)

    return GenerativeModelSession.Response(
      content: generatedContent, rawContent: generatedContent, rawResponse: response
    )
  }

  /// Generates a response to `prompt` decoded into the `Generable` type `Content`.
  ///
  /// - Parameters:
  ///   - prompt: The parts composing the prompt sent to the model.
  ///   - type: The `Generable` type to decode the response into.
  ///   - includeSchemaInPrompt: Whether to send the type's schema with the request.
  ///   - options: Per-request generation options; non-nil fields override the model defaults.
  /// - Returns: The decoded `Content` wrapped in a ``Response``.
  /// - Throws: ``GenerationError/decodingFailure(_:)`` if the response contains no text, or any
  ///   error thrown while generating content or decoding it into `Content`.
  @discardableResult
  public final nonisolated(nonsending)
  func respond<Content>(to prompt: PartsRepresentable...,
                        generating type: Content.Type = Content.self,
                        includeSchemaInPrompt: Bool = true,
                        options: GenerationConfig? = nil) async throws
    -> GenerativeModelSession.Response<Content> where Content: Generable {
    let response = try await respond(
      to: prompt,
      schema: type.generationSchema,
      includeSchemaInPrompt: includeSchemaInPrompt,
      options: options
    )

    let content = try Content(response.rawContent)

    return GenerativeModelSession.Response(
      content: content, rawContent: response.rawContent, rawResponse: response.rawResponse
    )
  }

  /// A model response carrying the decoded content alongside the raw payloads.
  public struct Response<Content> where Content: Generable {
    /// The decoded, typed content of the response.
    public let content: Content
    /// The raw `GeneratedContent` that `content` was decoded from.
    public let rawContent: GeneratedContent
    /// The full underlying response returned by the backend.
    public let rawResponse: GenerateContentResponse
  }

  /// Errors thrown while decoding a model response.
  public enum GenerationError: Error, LocalizedError {
    /// Additional debugging information accompanying a failure.
    public struct Context: Sendable {
      /// A human-readable description of the failure, intended for debugging only.
      public let debugDescription: String

      init(debugDescription: String) {
        self.debugDescription = debugDescription
      }
    }

    /// The model's response could not be decoded into the requested content.
    case decodingFailure(GenerativeModelSession.GenerationError.Context)
  }

  // MARK: - Private

  /// Merges the model-level config with per-request `options`, then forces text-only,
  /// single-candidate output (the defaults this session relies on).
  private func makeTextConfig(options: GenerationConfig?) -> GenerationConfig {
    var config = GenerationConfig.merge(
      generativeModel.generationConfig, with: options
    ) ?? GenerationConfig()
    config.responseModalities = nil // Override to the default (text only)
    config.candidateCount = nil // Override to the default (one candidate)
    return config
  }

  /// Sends the request and returns the response's text, throwing when none is present.
  private func generateText(_ content: [ModelContent],
                            config: GenerationConfig) async throws
    -> (text: String, response: GenerateContentResponse) {
    let response = try await generativeModel.generateContent(content, generationConfig: config)
    guard let text = response.text else {
      throw GenerationError.decodingFailure(
        GenerationError.Context(debugDescription: "No text in response: \(response)")
      )
    }
    return (text, response)
  }
}
#endif // canImport(FoundationModels)

0 commit comments

Comments
 (0)