diff --git a/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Sources/FirebaseAILogicMacros/FirebaseGenerableMacro.swift b/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Sources/FirebaseAILogicMacros/FirebaseGenerableMacro.swift index cb91c51224f..9dc4347a024 100644 --- a/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Sources/FirebaseAILogicMacros/FirebaseGenerableMacro.swift +++ b/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Sources/FirebaseAILogicMacros/FirebaseGenerableMacro.swift @@ -105,25 +105,47 @@ public struct FirebaseGenerableMacro: ExtensionMacro { } } - let caseChecks = caseNames.map { caseName in - "case \"\(caseName)\":\nself = .\(caseName)" - }.joined(separator: "\n") + let unexpectedValue: CodeBlockItemSyntax = """ + throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure( + FirebaseAILogic.GenerativeModel.GenerationError.Context( + debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)" + ) + ) + """ + let coreInitBody: CodeBlockItemSyntax + if isStringBacked(enumDecl: enumDecl) { + coreInitBody = """ + if let value = Self(rawValue: rawValue) { + self = value + } else { + \(unexpectedValue) + } + """ + } else { + let caseChecks = caseNames.map { caseName in + "case \"\(caseName)\":\nself = .\(caseName)" + }.joined(separator: "\n") + + coreInitBody = """ + switch rawValue { + \(raw: caseChecks) + default: + \(unexpectedValue) + } + """ + } + + let initBody = """ + let rawValue = try content.value(String.self) + \(coreInitBody) + """ var declarations = [ExtensionDeclSyntax]() let declSyntaxString = """ @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) extension \(type.trimmed): FirebaseAILogic.FirebaseGenerable { nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws { - let rawValue = try content.value(String.self) - switch rawValue { - \(caseChecks) - default: - throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure( - FirebaseAILogic.GenerativeModel.GenerationError.Context( - debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)" - ) - ) - } + \(initBody) } } """ @@ -137,6 +159,12 @@ public struct FirebaseGenerableMacro: ExtensionMacro { return declarations } + + private static func isStringBacked(enumDecl: EnumDeclSyntax) -> Bool { + return enumDecl.inheritanceClause?.inheritedTypes.contains { + $0.type.trimmed.description == "String" + } ?? false + } } struct PropertyInfo { @@ -167,7 +195,7 @@ extension FirebaseGenerableMacro: MemberMacro { private static func expansionForStruct(of node: SwiftSyntax.AttributeSyntax, structDecl: StructDeclSyntax) throws -> [SwiftSyntax.DeclSyntax] { - // Find the description for the struct itself from the @Generable macro. + // Find the description for the struct itself from the @FirebaseGenerable macro. let structDescription = try getDescriptionFromGenerableMacro(node) var propertyInfos = [PropertyInfo]() @@ -231,6 +259,7 @@ extension FirebaseGenerableMacro: MemberMacro { private static func expansionForEnum(of node: SwiftSyntax.AttributeSyntax, enumDecl: EnumDeclSyntax) throws -> [SwiftSyntax.DeclSyntax] { + var rawValues = [String]() var caseNames = [String]() for member in enumDecl.memberBlock.members { @@ -239,31 +268,61 @@ extension FirebaseGenerableMacro: MemberMacro { } for element in caseDecl.elements { - // Validation already happened in ExtensionMacro caseNames.append(element.name.text) + if let rawValueExpr = element.rawValue?.value.as(StringLiteralExprSyntax.self), + let segment = rawValueExpr.segments.first?.as(StringSegmentSyntax.self) { + rawValues.append(segment.content.text) + } else { + rawValues.append(element.name.text) + } } } + // Find the description for the enum itself from the @FirebaseGenerable macro. + let enumDescription = try getDescriptionFromGenerableMacro(node) + // Generate `static var jsonSchema: ...` computed property. - let anyOfList = caseNames.map { "\"\($0)\"" }.joined(separator: ", ") + let anyOfList: String + if isStringBacked(enumDecl: enumDecl) { + anyOfList = caseNames.map { "\($0).rawValue" }.joined(separator: ", ") + } else { + anyOfList = rawValues.map { "\"\($0)\"" }.joined(separator: ", ") + } + + var schemaParameters = ["type: Self.self"] + if let enumDescription { + schemaParameters.append("description: \"\(enumDescription)\"") + } + schemaParameters.append("anyOf: [\(anyOfList)]") + let schemaParametersCode = schemaParameters.joined(separator: ", ") + let generationSchemaCode = """ nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema { - FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [\(anyOfList)]) + FirebaseAILogic.JSONSchema(\(schemaParametersCode)) } """ // Generate `var modelOutput: ...` computed property. - let switchCases = caseNames.map { caseName in - "case .\(caseName):\n \"\(caseName)\".modelOutput" - }.joined(separator: "\n ") - - let modelOutputCode = """ - nonisolated var modelOutput: FirebaseAILogic.ModelOutput { - switch self { - \(switchCases) + let modelOutputCode: String + if isStringBacked(enumDecl: enumDecl) { + modelOutputCode = """ + nonisolated var modelOutput: FirebaseAILogic.ModelOutput { + rawValue.modelOutput } + """ + } else { + let switchCases = caseNames.map { caseName in + "case .\(caseName):\n \"\(caseName)\".modelOutput" + }.joined(separator: "\n ") + + modelOutputCode = """ + nonisolated var modelOutput: FirebaseAILogic.ModelOutput { + switch self { + \(switchCases) + } + } + """ } - """ return [ DeclSyntax(stringLiteral: generationSchemaCode), diff --git a/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Tests/FirebaseGenerableMacroTests/FirebaseGenerableTests.swift b/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Tests/FirebaseGenerableMacroTests/FirebaseGenerableTests.swift index 58b41813477..b00143cbd24 100644 --- a/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Tests/FirebaseGenerableMacroTests/FirebaseGenerableTests.swift +++ b/FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Tests/FirebaseGenerableMacroTests/FirebaseGenerableTests.swift @@ -120,7 +120,7 @@ final class FirebaseAILogicMacrosTests: XCTestCase { #if canImport(FirebaseAILogicMacros) assertMacroExpansion( """ - @FirebaseGenerable + @FirebaseGenerable(description: "A type of pet") enum Pet { case cat case dog @@ -134,7 +134,7 @@ final class FirebaseAILogicMacrosTests: XCTestCase { case fish nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema { - FirebaseAILogic.JSONSchema(type: Self.self, anyOf: ["cat", "dog", "fish"]) + FirebaseAILogic.JSONSchema(type: Self.self, description: "A type of pet", anyOf: ["cat", "dog", "fish"]) } nonisolated var modelOutput: FirebaseAILogic.ModelOutput { @@ -177,4 +177,54 @@ final class FirebaseAILogicMacrosTests: XCTestCase { throw XCTSkip("Macros are only supported when running tests for the host platform") #endif } + + func testEnumMacroWithRawValue() throws { + #if canImport(FirebaseAILogicMacros) + assertMacroExpansion( + """ + @FirebaseGenerable + enum Priority: String { + case high + case medium = "med" + case low + } + """, + expandedSource: """ + enum Priority: String { + case high + case medium = "med" + case low + + nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema { + FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [high.rawValue, medium.rawValue, low.rawValue]) + } + + nonisolated var modelOutput: FirebaseAILogic.ModelOutput { + rawValue.modelOutput + } + } + + @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) + extension Priority: FirebaseAILogic.FirebaseGenerable { + nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws { + let rawValue = try content.value(String.self) + if let value = Self(rawValue: rawValue) { + self = value + } else { + throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure( + FirebaseAILogic.GenerativeModel.GenerationError.Context( + debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)" + ) + ) + } + } + } + """, + macros: testMacros, + indentationWidth: .spaces(2) + ) + #else + throw XCTSkip("Macros are only supported when running tests for the host platform") + #endif + } } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift index 276785293e5..284e30180ad 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift @@ -487,6 +487,66 @@ struct SchemaTests { #expect(userProfile.middleName == nil) } + @FirebaseGenerable + struct Pet { + let name: String + let species: Species + + @FirebaseGenerable(description: "Animal species types") + enum Species { + case cat, dog + } + } + + @Test(arguments: testConfigs( + instanceConfigs: InstanceConfig.allConfigs, + openAPISchema: .object( + properties: [ + "name": .string(), + "species": .enumeration( + values: ["cat", "dog"], + description: "Animal species types" + ), + ], + title: "Pet" + ), + jsonSchema: Pet.jsonSchema + )) + func generateContentSimpleStringEnum(_ config: InstanceConfig, + _ schema: SchemaType) async throws { + print(Pet.jsonSchema.debugDescription) + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: SchemaTests.generationConfig(schema: schema), + safetySettings: safetySettings + ) + let prompt = "Create a pet cat named 'Fluffy'." + + let response = try await model.generateContent(prompt) + + let text = try #require(response.text) + let modelOutput = try ModelOutput(json: text) + let pet = try Pet(modelOutput) + #expect(pet.name == "Fluffy") + #expect(pet.species == .cat) + } + + @Test(arguments: InstanceConfig.allConfigs) + func generateTypeSimpleStringEnum(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: generationConfig, + safetySettings: safetySettings + ) + let prompt = "Create a pet dog named 'Buddy'." + + let response = try await model.generate(Pet.self, from: prompt) + + let pet = response.content + #expect(pet.name == "Buddy") + #expect(pet.species == .dog) + } + @FirebaseGenerable struct Task { let title: String @@ -495,8 +555,10 @@ struct SchemaTests { let priority: Priority @FirebaseGenerable - enum Priority { - case low, medium, high + enum Priority: String { + case low + case medium = "med" + case high } } @@ -506,7 +568,7 @@ struct SchemaTests { properties: [ "title": .string(), "priority": .enumeration( - values: ["low", "medium", "high"], + values: ["low", "med", "high"], description: "The priority level" ), ], @@ -514,7 +576,8 @@ struct SchemaTests { ), jsonSchema: Task.jsonSchema )) - func generateContentStringEnum(_ config: InstanceConfig, _ schema: SchemaType) async throws { + func generateContentStringRawValueEnum(_ config: InstanceConfig, + _ schema: SchemaType) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( modelName: ModelNames.gemini2_5_FlashLite, generationConfig: SchemaTests.generationConfig(schema: schema), @@ -532,7 +595,7 @@ struct SchemaTests { } @Test(arguments: InstanceConfig.allConfigs) - func generateTypeStringEnum(_ config: InstanceConfig) async throws { + func generateTypeStringRawValueEnum(_ config: InstanceConfig) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( modelName: ModelNames.gemini2_5_FlashLite, generationConfig: generationConfig,