Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions FirebaseAI/Sources/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public struct GenerationConfig: Sendable {
/// Output schema of the generated candidate text.
let responseSchema: Schema?

/// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format.
///
/// If set, `responseSchema` must be omitted and `responseMIMEType` is required.
let responseJSONSchema: JSONObject?

/// Supported modalities of the response.
let responseModalities: [ResponseModality]?

Expand Down Expand Up @@ -175,6 +180,26 @@ public struct GenerationConfig: Sendable {
self.stopSequences = stopSequences
self.responseMIMEType = responseMIMEType
self.responseSchema = responseSchema
responseJSONSchema = nil
self.responseModalities = responseModalities
self.thinkingConfig = thinkingConfig
}

init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil,
maxOutputTokens: Int? = nil, presencePenalty: Float? = nil, frequencyPenalty: Float? = nil,
stopSequences: [String]? = nil, responseMIMEType: String, responseJSONSchema: JSONObject,
responseModalities: [ResponseModality]? = nil, thinkingConfig: ThinkingConfig? = nil) {
self.temperature = temperature
self.topP = topP
self.topK = topK
self.candidateCount = candidateCount
self.maxOutputTokens = maxOutputTokens
self.presencePenalty = presencePenalty
self.frequencyPenalty = frequencyPenalty
self.stopSequences = stopSequences
self.responseMIMEType = responseMIMEType
responseSchema = nil
self.responseJSONSchema = responseJSONSchema
self.responseModalities = responseModalities
self.thinkingConfig = thinkingConfig
}
Expand All @@ -195,6 +220,7 @@ extension GenerationConfig: Encodable {
case stopSequences
case responseMIMEType = "responseMimeType"
case responseSchema
case responseJSONSchema = "responseJsonSchema"
case responseModalities
case thinkingConfig
}
Expand Down
270 changes: 248 additions & 22 deletions FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ struct SchemaTests {
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentJSONSchemaItems(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: GenerationConfig(
responseMIMEType: "application/json",
responseJSONSchema: [
"type": .string("array"),
"description": .string("A list of city names"),
"items": .object([
"type": .string("string"),
"description": .string("The name of the city"),
]),
"minItems": .number(3),
"maxItems": .number(5),
]
),
safetySettings: safetySettings
)
let prompt = "What are the biggest cities in Canada?"
let response = try await model.generateContent(prompt)
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
let jsonData = try #require(text.data(using: .utf8))
let decodedJSON = try JSONDecoder().decode([String].self, from: jsonData)
#expect(decodedJSON.count >= 3, "Expected at least 3 cities, but got \(decodedJSON.count)")
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
}
Comment on lines +77 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a lot of duplicated code between this test and generateContentSchemaItems. The logic for making the request and asserting the response is identical. Consider extracting this common logic into a private helper function to improve maintainability and reduce code duplication. This new helper could take a GenerativeModel instance as a parameter.

This same pattern of duplication exists for other new tests in this file (generateContentJSONSchemaNumberRange, generateContentJSONSchemaNumberRangeMultiType, generateContentAnyOfJSONSchema), and they would also benefit from a similar refactoring.


@Test(arguments: InstanceConfig.allConfigs)
func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
Expand All @@ -96,14 +124,41 @@ struct SchemaTests {
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentJSONSchemaNumberRange(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: GenerationConfig(
responseMIMEType: "application/json",
responseJSONSchema: [
"type": .string("integer"),
"description": .string("A number"),
"minimum": .number(110),
"maximum": .number(120),
]
),
safetySettings: safetySettings
)
let prompt = "Give me a number"

let response = try await model.generateContent(prompt)

let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
let jsonData = try #require(text.data(using: .utf8))
let decodedNumber = try JSONDecoder().decode(Double.self, from: jsonData)
#expect(decodedNumber >= 110.0, "Expected a number >= 110, but got \(decodedNumber)")
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
}

private struct ProductInfo: Codable {
let productName: String
let rating: Int
let price: Double
let salePrice: Float
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
struct ProductInfo: Codable {
let productName: String
let rating: Int // Will correspond to .integer in schema
let price: Double // Will correspond to .double in schema
let salePrice: Float // Will correspond to .float in schema
}
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
generationConfig: GenerationConfig(
Expand Down Expand Up @@ -150,28 +205,95 @@ struct SchemaTests {
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
struct MailingAddress: Decodable {
let streetAddress: String
let city: String
func generateContentJSONSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: GenerationConfig(
responseMIMEType: "application/json",
responseJSONSchema: [
"type": .string("object"),
"title": .string("ProductInfo"),
"properties": .object([
"productName": .object([
"type": .string("string"),
"description": .string("The name of the product"),
]),
"price": .object([
"type": .string("number"),
"description": .string("A price"),
"minimum": .number(10.00),
"maximum": .number(120.00),
]),
"salePrice": .object([
"type": .string("number"),
"description": .string("A sale price"),
"minimum": .number(5.00),
"maximum": .number(90.00),
]),
"rating": .object([
"type": .string("integer"),
"description": .string("A rating"),
"minimum": .number(1),
"maximum": .number(5),
]),
]),
"required": .array([
.string("productName"),
.string("price"),
.string("salePrice"),
.string("rating"),
]),
"propertyOrdering": .array([
.string("salePrice"),
.string("rating"),
.string("price"),
.string("productName"),
]),
]
),
safetySettings: safetySettings
)
let prompt = "Describe a premium wireless headphone, including a user rating and price."

// Canadian-specific
let province: String?
let postalCode: String?
let response = try await model.generateContent(prompt)

// U.S.-specific
let state: String?
let zipCode: String?
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
let jsonData = try #require(text.data(using: .utf8))
let decodedProduct = try JSONDecoder().decode(ProductInfo.self, from: jsonData)
let price = decodedProduct.price
let salePrice = decodedProduct.salePrice
let rating = decodedProduct.rating
#expect(price >= 10.0, "Expected a price >= 10.00, but got \(price)")
#expect(price <= 120.0, "Expected a price <= 120.00, but got \(price)")
#expect(salePrice >= 5.0, "Expected a salePrice >= 5.00, but got \(salePrice)")
#expect(salePrice <= 90.0, "Expected a salePrice <= 90.00, but got \(salePrice)")
#expect(rating >= 1, "Expected a rating >= 1, but got \(rating)")
#expect(rating <= 5, "Expected a rating <= 5, but got \(rating)")
}

private struct MailingAddress: Decodable {
let streetAddress: String
let city: String

// Canadian-specific
let province: String?
let postalCode: String?

var isCanadian: Bool {
return province != nil && postalCode != nil && state == nil && zipCode == nil
}
// U.S.-specific
let state: String?
let zipCode: String?

var isAmerican: Bool {
return province == nil && postalCode == nil && state != nil && zipCode != nil
}
var isCanadian: Bool {
return province != nil && postalCode != nil && state == nil && zipCode == nil
}

var isAmerican: Bool {
return province == nil && postalCode == nil && state != nil && zipCode != nil
}
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
let streetSchema = Schema.string(description:
"The civic number and street name, for example, '123 Main Street'.")
let citySchema = Schema.string(description: "The name of the city.")
Expand Down Expand Up @@ -232,4 +354,108 @@ struct SchemaTests {
"Expected Canadian Queen's University address, got \(queensAddress)."
)
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentAnyOfJSONSchema(_ config: InstanceConfig) async throws {
let streetSchema: JSONValue = .object([
"type": .string("string"),
"description": .string("The civic number and street name, for example, '123 Main Street'."),
])
let citySchema: JSONValue = .object([
"type": .string("string"),
"description": .string("The name of the city."),
])
let canadianAddressSchema: JSONObject = [
"type": .string("object"),
"description": .string("A Canadian mailing address"),
"properties": .object([
"streetAddress": streetSchema,
"city": citySchema,
"province": .object([
"type": .string("string"),
"description": .string(
"The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."
),
]),
"postalCode": .object([
"type": .string("string"),
"description": .string("The postal code, for example, 'A1A 1A1'."),
]),
]),
"required": .array([
.string("streetAddress"),
.string("city"),
.string("province"),
.string("postalCode"),
]),
]
let americanAddressSchema: JSONObject = [
"type": .string("object"),
"description": .string("A U.S. mailing address"),
"properties": .object([
"streetAddress": streetSchema,
"city": citySchema,
"state": .object([
"type": .string("string"),
"description": .string(
"The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."
),
]),
"zipCode": .object([
"type": .string("string"),
"description": .string("The 5-digit ZIP code, for example, '12345'."),
]),
]),
"required": .array([
.string("streetAddress"),
.string("city"),
.string("state"),
.string("zipCode"),
]),
]
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_Flash,
generationConfig: GenerationConfig(
temperature: 0.0,
topP: 0.0,
topK: 1,
responseMIMEType: "application/json",
responseJSONSchema: [
"type": .string("array"),
"items": .object([
"anyOf": .array([
.object(canadianAddressSchema),
.object(americanAddressSchema),
]),
]),
]
),
safetySettings: safetySettings
)
let prompt = """
What are the mailing addresses for the University of Waterloo, UC Berkeley and Queen's U?
"""

let response = try await model.generateContent(prompt)

let text = try #require(response.text)
let jsonData = try #require(text.data(using: .utf8))
let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData)
try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).")
let waterlooAddress = decodedAddresses[0]
#expect(
waterlooAddress.isCanadian,
"Expected Canadian University of Waterloo address, got \(waterlooAddress)."
)
let berkeleyAddress = decodedAddresses[1]
#expect(
berkeleyAddress.isAmerican,
"Expected American UC Berkeley address, got \(berkeleyAddress)."
)
let queensAddress = decodedAddresses[2]
#expect(
queensAddress.isCanadian,
"Expected Canadian Queen's University address, got \(queensAddress)."
)
}
}
Loading
Loading