Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,47 @@ public struct FirebaseGenerableMacro: ExtensionMacro {
}
}

let caseChecks = caseNames.map { caseName in
"case \"\(caseName)\":\nself = .\(caseName)"
}.joined(separator: "\n")
let unexpectedValue: CodeBlockItemSyntax = """
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
FirebaseAILogic.GenerativeModel.GenerationError.Context(
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
)
)
"""
let coreInitBody: CodeBlockItemSyntax
if isStringBacked(enumDecl: enumDecl) {
coreInitBody = """
if let value = Self(rawValue: rawValue) {
self = value
} else {
\(unexpectedValue)
}
"""
} else {
let caseChecks = caseNames.map { caseName in
"case \"\(caseName)\":\nself = .\(caseName)"
}.joined(separator: "\n")

coreInitBody = """
switch rawValue {
\(raw: caseChecks)
default:
\(unexpectedValue)
}
"""
}

let initBody = """
let rawValue = try content.value(String.self)
\(coreInitBody)
"""

var declarations = [ExtensionDeclSyntax]()
let declSyntaxString = """
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension \(type.trimmed): FirebaseAILogic.FirebaseGenerable {
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
let rawValue = try content.value(String.self)
switch rawValue {
\(caseChecks)
default:
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
FirebaseAILogic.GenerativeModel.GenerationError.Context(
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
)
)
}
\(initBody)
}
}
"""
Expand All @@ -137,6 +159,12 @@ public struct FirebaseGenerableMacro: ExtensionMacro {

return declarations
}

private static func isStringBacked(enumDecl: EnumDeclSyntax) -> Bool {
return enumDecl.inheritanceClause?.inheritedTypes.contains {
$0.type.trimmed.description == "String"
} ?? false
}
}

struct PropertyInfo {
Expand Down Expand Up @@ -167,7 +195,7 @@ extension FirebaseGenerableMacro: MemberMacro {
private static func expansionForStruct(of node: SwiftSyntax.AttributeSyntax,
structDecl: StructDeclSyntax) throws
-> [SwiftSyntax.DeclSyntax] {
// Find the description for the struct itself from the @Generable macro.
// Find the description for the struct itself from the @FirebaseGenerable macro.
let structDescription = try getDescriptionFromGenerableMacro(node)

var propertyInfos = [PropertyInfo]()
Expand Down Expand Up @@ -231,6 +259,7 @@ extension FirebaseGenerableMacro: MemberMacro {
private static func expansionForEnum(of node: SwiftSyntax.AttributeSyntax,
enumDecl: EnumDeclSyntax) throws
-> [SwiftSyntax.DeclSyntax] {
var rawValues = [String]()
var caseNames = [String]()

for member in enumDecl.memberBlock.members {
Expand All @@ -239,31 +268,61 @@ extension FirebaseGenerableMacro: MemberMacro {
}

for element in caseDecl.elements {
// Validation already happened in ExtensionMacro
caseNames.append(element.name.text)
if let rawValueExpr = element.rawValue?.value.as(StringLiteralExprSyntax.self),
let segment = rawValueExpr.segments.first?.as(StringSegmentSyntax.self) {
rawValues.append(segment.content.text)
} else {
rawValues.append(element.name.text)
}
}
}

// Find the description for the enum itself from the @FirebaseGenerable macro.
let enumDescription = try getDescriptionFromGenerableMacro(node)

// Generate `static var jsonSchema: ...` computed property.
let anyOfList = caseNames.map { "\"\($0)\"" }.joined(separator: ", ")
let anyOfList: String
if isStringBacked(enumDecl: enumDecl) {
anyOfList = caseNames.map { "\($0).rawValue" }.joined(separator: ", ")
} else {
anyOfList = rawValues.map { "\"\($0)\"" }.joined(separator: ", ")
}

var schemaParameters = ["type: Self.self"]
if let enumDescription {
schemaParameters.append("description: \"\(enumDescription)\"")
}
schemaParameters.append("anyOf: [\(anyOfList)]")
let schemaParametersCode = schemaParameters.joined(separator: ", ")

let generationSchemaCode = """
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [\(anyOfList)])
FirebaseAILogic.JSONSchema(\(schemaParametersCode))
}
"""

// Generate `var modelOutput: ...` computed property.
let switchCases = caseNames.map { caseName in
"case .\(caseName):\n \"\(caseName)\".modelOutput"
}.joined(separator: "\n ")

let modelOutputCode = """
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
switch self {
\(switchCases)
let modelOutputCode: String
if isStringBacked(enumDecl: enumDecl) {
modelOutputCode = """
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
rawValue.modelOutput
}
"""
} else {
let switchCases = caseNames.map { caseName in
"case .\(caseName):\n \"\(caseName)\".modelOutput"
}.joined(separator: "\n ")

modelOutputCode = """
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
switch self {
\(switchCases)
}
}
"""
}
"""

return [
DeclSyntax(stringLiteral: generationSchemaCode),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
#if canImport(FirebaseAILogicMacros)
assertMacroExpansion(
"""
@FirebaseGenerable
@FirebaseGenerable(description: "A type of pet")
enum Pet {
case cat
case dog
Expand All @@ -134,7 +134,7 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
case fish

nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: ["cat", "dog", "fish"])
FirebaseAILogic.JSONSchema(type: Self.self, description: "A type of pet", anyOf: ["cat", "dog", "fish"])
}

nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
Expand Down Expand Up @@ -177,4 +177,54 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
throw XCTSkip("Macros are only supported when running tests for the host platform")
#endif
}

func testEnumMacroWithRawValue() throws {
#if canImport(FirebaseAILogicMacros)
assertMacroExpansion(
"""
@FirebaseGenerable
enum Priority: String {
case high
case medium = "med"
case low
}
""",
expandedSource: """
enum Priority: String {
case high
case medium = "med"
case low

nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [high.rawValue, medium.rawValue, low.rawValue])
}

nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
rawValue.modelOutput
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension Priority: FirebaseAILogic.FirebaseGenerable {
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
let rawValue = try content.value(String.self)
if let value = Self(rawValue: rawValue) {
self = value
} else {
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
FirebaseAILogic.GenerativeModel.GenerationError.Context(
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
)
)
}
}
}
""",
macros: testMacros,
indentationWidth: .spaces(2)
)
#else
throw XCTSkip("Macros are only supported when running tests for the host platform")
#endif
}
}
73 changes: 68 additions & 5 deletions FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,66 @@ struct SchemaTests {
#expect(userProfile.middleName == nil)
}

@FirebaseGenerable
struct Pet {
let name: String
let species: Species

@FirebaseGenerable(description: "Animal species types")
enum Species {
case cat, dog
}
}

@Test(arguments: testConfigs(
instanceConfigs: InstanceConfig.allConfigs,
openAPISchema: .object(
properties: [
"name": .string(),
"species": .enumeration(
values: ["cat", "dog"],
description: "Animal species types"
),
],
title: "Pet"
),
jsonSchema: Pet.jsonSchema
))
func generateContentSimpleStringEnum(_ config: InstanceConfig,
_ schema: SchemaType) async throws {
print(Pet.jsonSchema.debugDescription)
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: SchemaTests.generationConfig(schema: schema),
safetySettings: safetySettings
)
let prompt = "Create a pet cat named 'Fluffy'."

let response = try await model.generateContent(prompt)

let text = try #require(response.text)
let modelOutput = try ModelOutput(json: text)
let pet = try Pet(modelOutput)
#expect(pet.name == "Fluffy")
#expect(pet.species == .cat)
}

@Test(arguments: InstanceConfig.allConfigs)
func generateTypeSimpleStringEnum(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: generationConfig,
safetySettings: safetySettings
)
let prompt = "Create a pet dog named 'Buddy'."

let response = try await model.generate(Pet.self, from: prompt)

let pet = response.content
#expect(pet.name == "Buddy")
#expect(pet.species == .dog)
}

@FirebaseGenerable
struct Task {
let title: String
Expand All @@ -495,8 +555,10 @@ struct SchemaTests {
let priority: Priority

@FirebaseGenerable
enum Priority {
case low, medium, high
enum Priority: String {
case low
case medium = "med"
case high
}
}

Expand All @@ -506,15 +568,16 @@ struct SchemaTests {
properties: [
"title": .string(),
"priority": .enumeration(
values: ["low", "medium", "high"],
values: ["low", "med", "high"],
description: "The priority level"
),
],
title: "Task"
),
jsonSchema: Task.jsonSchema
))
func generateContentStringEnum(_ config: InstanceConfig, _ schema: SchemaType) async throws {
func generateContentStringRawValueEnum(_ config: InstanceConfig,
_ schema: SchemaType) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: SchemaTests.generationConfig(schema: schema),
Expand All @@ -532,7 +595,7 @@ struct SchemaTests {
}

@Test(arguments: InstanceConfig.allConfigs)
func generateTypeStringEnum(_ config: InstanceConfig) async throws {
func generateTypeStringRawValueEnum(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2_5_FlashLite,
generationConfig: generationConfig,
Expand Down
Loading