diff --git a/FirebaseAI/Sources/GenerationConfig.swift b/FirebaseAI/Sources/GenerationConfig.swift index 27c4310f12d..fe2b6963e22 100644 --- a/FirebaseAI/Sources/GenerationConfig.swift +++ b/FirebaseAI/Sources/GenerationConfig.swift @@ -48,6 +48,11 @@ public struct GenerationConfig: Sendable { /// Output schema of the generated candidate text. let responseSchema: Schema? + /// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format. + /// + /// If set, `responseSchema` must be omitted and `responseMIMEType` is required. + let responseJSONSchema: JSONObject? + /// Supported modalities of the response. let responseModalities: [ResponseModality]? @@ -175,6 +180,26 @@ public struct GenerationConfig: Sendable { self.stopSequences = stopSequences self.responseMIMEType = responseMIMEType self.responseSchema = responseSchema + responseJSONSchema = nil + self.responseModalities = responseModalities + self.thinkingConfig = thinkingConfig + } + + init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil, + maxOutputTokens: Int? = nil, presencePenalty: Float? = nil, frequencyPenalty: Float? = nil, + stopSequences: [String]? = nil, responseMIMEType: String, responseJSONSchema: JSONObject, + responseModalities: [ResponseModality]? = nil, thinkingConfig: ThinkingConfig? = nil) { + self.temperature = temperature + self.topP = topP + self.topK = topK + self.candidateCount = candidateCount + self.maxOutputTokens = maxOutputTokens + self.presencePenalty = presencePenalty + self.frequencyPenalty = frequencyPenalty + self.stopSequences = stopSequences + self.responseMIMEType = responseMIMEType + responseSchema = nil + self.responseJSONSchema = responseJSONSchema self.responseModalities = responseModalities self.thinkingConfig = thinkingConfig } @@ -195,6 +220,7 @@ extension GenerationConfig: Encodable { case stopSequences case responseMIMEType = "responseMimeType" case responseSchema + case responseJSONSchema = "responseJsonSchema" case responseModalities case thinkingConfig } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift index 4f4dd1e3dc8..f2f3d06441b 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift @@ -73,6 +73,34 @@ struct SchemaTests { #expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)") } + @Test(arguments: InstanceConfig.allConfigs) + func generateContentJSONSchemaItems(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: GenerationConfig( + responseMIMEType: "application/json", + responseJSONSchema: [ + "type": .string("array"), + "description": .string("A list of city names"), + "items": .object([ + "type": .string("string"), + "description": .string("The name of the city"), + ]), + "minItems": .number(3), + "maxItems": .number(5), + ] + ), + safetySettings: safetySettings + ) + let prompt = "What are the biggest cities in Canada?" + let response = try await model.generateContent(prompt) + let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) + let jsonData = try #require(text.data(using: .utf8)) + let decodedJSON = try JSONDecoder().decode([String].self, from: jsonData) + #expect(decodedJSON.count >= 3, "Expected at least 3 cities, but got \(decodedJSON.count)") + #expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)") + } + @Test(arguments: InstanceConfig.allConfigs) func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( @@ -96,14 +124,41 @@ struct SchemaTests { #expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)") } + @Test(arguments: InstanceConfig.allConfigs) + func generateContentJSONSchemaNumberRange(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: GenerationConfig( + responseMIMEType: "application/json", + responseJSONSchema: [ + "type": .string("integer"), + "description": .string("A number"), + "minimum": .number(110), + "maximum": .number(120), + ] + ), + safetySettings: safetySettings + ) + let prompt = "Give me a number" + + let response = try await model.generateContent(prompt) + + let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) + let jsonData = try #require(text.data(using: .utf8)) + let decodedNumber = try JSONDecoder().decode(Double.self, from: jsonData) + #expect(decodedNumber >= 110.0, "Expected a number >= 110, but got \(decodedNumber)") + #expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)") + } + + private struct ProductInfo: Codable { + let productName: String + let rating: Int + let price: Double + let salePrice: Float + } + @Test(arguments: InstanceConfig.allConfigs) func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws { - struct ProductInfo: Codable { - let productName: String - let rating: Int // Will correspond to .integer in schema - let price: Double // Will correspond to .double in schema - let salePrice: Float // Will correspond to .float in schema - } let model = FirebaseAI.componentInstance(config).generativeModel( modelName: ModelNames.gemini2FlashLite, generationConfig: GenerationConfig( @@ -150,28 +205,95 @@ struct SchemaTests { } @Test(arguments: InstanceConfig.allConfigs) - func generateContentAnyOfSchema(_ config: InstanceConfig) async throws { - struct MailingAddress: Decodable { - let streetAddress: String - let city: String + func generateContentJSONSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: GenerationConfig( + responseMIMEType: "application/json", + responseJSONSchema: [ + "type": .string("object"), + "title": .string("ProductInfo"), + "properties": .object([ + "productName": .object([ + "type": .string("string"), + "description": .string("The name of the product"), + ]), + "price": .object([ + "type": .string("number"), + "description": .string("A price"), + "minimum": .number(10.00), + "maximum": .number(120.00), + ]), + "salePrice": .object([ + "type": .string("number"), + "description": .string("A sale price"), + "minimum": .number(5.00), + "maximum": .number(90.00), + ]), + "rating": .object([ + "type": .string("integer"), + "description": .string("A rating"), + "minimum": .number(1), + "maximum": .number(5), + ]), + ]), + "required": .array([ + .string("productName"), + .string("price"), + .string("salePrice"), + .string("rating"), + ]), + "propertyOrdering": .array([ + .string("salePrice"), + .string("rating"), + .string("price"), + .string("productName"), + ]), + ] + ), + safetySettings: safetySettings + ) + let prompt = "Describe a premium wireless headphone, including a user rating and price." - // Canadian-specific - let province: String? - let postalCode: String? + let response = try await model.generateContent(prompt) - // U.S.-specific - let state: String? - let zipCode: String? + let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) + let jsonData = try #require(text.data(using: .utf8)) + let decodedProduct = try JSONDecoder().decode(ProductInfo.self, from: jsonData) + let price = decodedProduct.price + let salePrice = decodedProduct.salePrice + let rating = decodedProduct.rating + #expect(price >= 10.0, "Expected a price >= 10.00, but got \(price)") + #expect(price <= 120.0, "Expected a price <= 120.00, but got \(price)") + #expect(salePrice >= 5.0, "Expected a salePrice >= 5.00, but got \(salePrice)") + #expect(salePrice <= 90.0, "Expected a salePrice <= 90.00, but got \(salePrice)") + #expect(rating >= 1, "Expected a rating >= 1, but got \(rating)") + #expect(rating <= 5, "Expected a rating <= 5, but got \(rating)") + } + + private struct MailingAddress: Decodable { + let streetAddress: String + let city: String + + // Canadian-specific + let province: String? + let postalCode: String? - var isCanadian: Bool { - return province != nil && postalCode != nil && state == nil && zipCode == nil - } + // U.S.-specific + let state: String? + let zipCode: String? - var isAmerican: Bool { - return province == nil && postalCode == nil && state != nil && zipCode != nil - } + var isCanadian: Bool { + return province != nil && postalCode != nil && state == nil && zipCode == nil } + var isAmerican: Bool { + return province == nil && postalCode == nil && state != nil && zipCode != nil + } + } + + @Test(arguments: InstanceConfig.allConfigs) + func generateContentAnyOfSchema(_ config: InstanceConfig) async throws { let streetSchema = Schema.string(description: "The civic number and street name, for example, '123 Main Street'.") let citySchema = Schema.string(description: "The name of the city.") @@ -232,4 +354,108 @@ struct SchemaTests { "Expected Canadian Queen's University address, got \(queensAddress)." ) } + + @Test(arguments: InstanceConfig.allConfigs) + func generateContentAnyOfJSONSchema(_ config: InstanceConfig) async throws { + let streetSchema: JSONValue = .object([ + "type": .string("string"), + "description": .string("The civic number and street name, for example, '123 Main Street'."), + ]) + let citySchema: JSONValue = .object([ + "type": .string("string"), + "description": .string("The name of the city."), + ]) + let canadianAddressSchema: JSONObject = [ + "type": .string("object"), + "description": .string("A Canadian mailing address"), + "properties": .object([ + "streetAddress": streetSchema, + "city": citySchema, + "province": .object([ + "type": .string("string"), + "description": .string( + "The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'." + ), + ]), + "postalCode": .object([ + "type": .string("string"), + "description": .string("The postal code, for example, 'A1A 1A1'."), + ]), + ]), + "required": .array([ + .string("streetAddress"), + .string("city"), + .string("province"), + .string("postalCode"), + ]), + ] + let americanAddressSchema: JSONObject = [ + "type": .string("object"), + "description": .string("A U.S. mailing address"), + "properties": .object([ + "streetAddress": streetSchema, + "city": citySchema, + "state": .object([ + "type": .string("string"), + "description": .string( + "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'." + ), + ]), + "zipCode": .object([ + "type": .string("string"), + "description": .string("The 5-digit ZIP code, for example, '12345'."), + ]), + ]), + "required": .array([ + .string("streetAddress"), + .string("city"), + .string("state"), + .string("zipCode"), + ]), + ] + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_Flash, + generationConfig: GenerationConfig( + temperature: 0.0, + topP: 0.0, + topK: 1, + responseMIMEType: "application/json", + responseJSONSchema: [ + "type": .string("array"), + "items": .object([ + "anyOf": .array([ + .object(canadianAddressSchema), + .object(americanAddressSchema), + ]), + ]), + ] + ), + safetySettings: safetySettings + ) + let prompt = """ + What are the mailing addresses for the University of Waterloo, UC Berkeley and Queen's U? + """ + + let response = try await model.generateContent(prompt) + + let text = try #require(response.text) + let jsonData = try #require(text.data(using: .utf8)) + let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData) + try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).") + let waterlooAddress = decodedAddresses[0] + #expect( + waterlooAddress.isCanadian, + "Expected Canadian University of Waterloo address, got \(waterlooAddress)." + ) + let berkeleyAddress = decodedAddresses[1] + #expect( + berkeleyAddress.isAmerican, + "Expected American UC Berkeley address, got \(berkeleyAddress)." + ) + let queensAddress = decodedAddresses[2] + #expect( + queensAddress.isCanadian, + "Expected Canadian Queen's University address, got \(queensAddress)." + ) + } } diff --git a/FirebaseAI/Tests/Unit/GenerationConfigTests.swift b/FirebaseAI/Tests/Unit/GenerationConfigTests.swift index 2b38d1898d4..c72fea56ed7 100644 --- a/FirebaseAI/Tests/Unit/GenerationConfigTests.swift +++ b/FirebaseAI/Tests/Unit/GenerationConfigTests.swift @@ -153,4 +153,85 @@ final class GenerationConfigTests: XCTestCase { } """) } + + func testEncodeGenerationConfig_responseJSONSchema() throws { + let mimeType = "application/json" + let responseJSONSchema: JSONObject = [ + "type": .string("object"), + "title": .string("Person"), + "properties": .object([ + "firstName": .object(["type": .string("string")]), + "middleNames": .object([ + "type": .string("array"), + "items": .object(["type": .string("string")]), + "minItems": .number(0), + "maxItems": .number(3), + ]), + "lastName": .object(["type": .string("string")]), + "age": .object(["type": .string("integer")]), + ]), + "required": .array([ + .string("firstName"), + .string("middleNames"), + .string("lastName"), + .string("age"), + ]), + "propertyOrdering": .array([ + .string("firstName"), + .string("middleNames"), + .string("lastName"), + .string("age"), + ]), + "additionalProperties": .bool(false), + ] + let generationConfig = GenerationConfig( + responseMIMEType: mimeType, + responseJSONSchema: responseJSONSchema + ) + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "responseJsonSchema" : { + "additionalProperties" : false, + "properties" : { + "age" : { + "type" : "integer" + }, + "firstName" : { + "type" : "string" + }, + "lastName" : { + "type" : "string" + }, + "middleNames" : { + "items" : { + "type" : "string" + }, + "maxItems" : 3, + "minItems" : 0, + "type" : "array" + } + }, + "propertyOrdering" : [ + "firstName", + "middleNames", + "lastName", + "age" + ], + "required" : [ + "firstName", + "middleNames", + "lastName", + "age" + ], + "title" : "Person", + "type" : "object" + }, + "responseMimeType" : "\(mimeType)" + } + """) + } }