Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ extension ResponseID.Identifier: Equatable {
} else if case let .generationID(lhsGenerationID) = lhs,
case let .generationID(rhsGenerationID) = rhs {
#if canImport(FoundationModels)
if #available(iOS 26.0, macOS 26.0, *) {
if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) {
guard let lhsGenerationID = lhsGenerationID as? GenerationID,
let rhsGenerationID = rhsGenerationID as? GenerationID else {
return false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,20 @@ public struct FirebaseGenerableMacro: ExtensionMacro {
in _: some SwiftSyntaxMacros
.MacroExpansionContext) throws
-> [SwiftSyntax.ExtensionDeclSyntax] {
guard let structDecl = declaration.as(StructDeclSyntax.self) else {
throw MacroExpansionErrorMessage("`@FirebaseGenerable` can only be applied to a struct.")
if let structDecl = declaration.as(StructDeclSyntax.self) {
return try expansionForStruct(of: structDecl, type: type)
} else if let enumDecl = declaration.as(EnumDeclSyntax.self) {
return try expansionForEnum(of: enumDecl, type: type)
} else {
throw MacroExpansionErrorMessage(
"`@FirebaseGenerable` can only be applied to a struct or enum."
)
}
}

private static func expansionForStruct(of structDecl: StructDeclSyntax,
type: some SwiftSyntax.TypeSyntaxProtocol) throws
-> [SwiftSyntax.ExtensionDeclSyntax] {
var propertyInits = [String]()

for member in structDecl.memberBlock.members {
Expand Down Expand Up @@ -66,8 +76,62 @@ public struct FirebaseGenerableMacro: ExtensionMacro {
}
"""
guard let extensionDecl = declSyntax.as(ExtensionDeclSyntax.self) else {
// TODO: Throw an error
return []
throw MacroExpansionErrorMessage("""
Failed to generate `FirebaseGenerable` extension for struct `\(type.trimmed)`.
""")
}
declarations.append(extensionDecl)

return declarations
}

private static func expansionForEnum(of enumDecl: EnumDeclSyntax,
type: some SwiftSyntax.TypeSyntaxProtocol) throws
-> [SwiftSyntax.ExtensionDeclSyntax] {
var caseNames = [String]()

for member in enumDecl.memberBlock.members {
guard let caseDecl = member.decl.as(EnumCaseDeclSyntax.self) else {
continue
}

for element in caseDecl.elements {
if element.parameterClause != nil {
throw MacroExpansionErrorMessage(
"`@FirebaseGenerable` does not currently support enums with associated values."
)
}
caseNames.append(element.name.text)
}
}

let caseChecks = caseNames.map { caseName in
"case \"\(caseName)\":\nself = .\(caseName)"
}.joined(separator: "\n")

var declarations = [ExtensionDeclSyntax]()
let declSyntaxString = """
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension \(type.trimmed): FirebaseAILogic.FirebaseGenerable {
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
let rawValue = try content.value(String.self)
switch rawValue {
\(caseChecks)
default:
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
FirebaseAILogic.GenerativeModel.GenerationError.Context(
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
)
)
}
}
}
"""
let declSyntax = DeclSyntax(stringLiteral: declSyntaxString)
guard let extensionDecl = declSyntax.as(ExtensionDeclSyntax.self) else {
throw MacroExpansionErrorMessage("""
Failed to generate `FirebaseGenerable` extension for enum `\(type.trimmed)`.
""")
}
declarations.append(extensionDecl)

Expand All @@ -89,11 +153,20 @@ extension FirebaseGenerableMacro: MemberMacro {
conformingTo _: [SwiftSyntax.TypeSyntax],
in _: some SwiftSyntaxMacros
.MacroExpansionContext) throws -> [SwiftSyntax.DeclSyntax] {
// Ensure the macro is attached to a struct declaration.
guard let structDecl = declaration.as(StructDeclSyntax.self) else {
throw MacroExpansionErrorMessage("`@Generable` can only be applied to a struct.")
if let structDecl = declaration.as(StructDeclSyntax.self) {
return try expansionForStruct(of: node, structDecl: structDecl)
} else if let enumDecl = declaration.as(EnumDeclSyntax.self) {
return try expansionForEnum(of: node, enumDecl: enumDecl)
} else {
throw MacroExpansionErrorMessage(
"`@FirebaseGenerable` can only be applied to a struct or enum."
)
}
}

private static func expansionForStruct(of node: SwiftSyntax.AttributeSyntax,
structDecl: StructDeclSyntax) throws
-> [SwiftSyntax.DeclSyntax] {
// Find the description for the struct itself from the @Generable macro.
let structDescription = try getDescriptionFromGenerableMacro(node)

Expand Down Expand Up @@ -149,16 +222,57 @@ extension FirebaseGenerableMacro: MemberMacro {
))
}

let declarations = generateMembers(
return generateStructMembers(
structDescription: structDescription,
propertyInfos: propertyInfos
)
}

return declarations
private static func expansionForEnum(of node: SwiftSyntax.AttributeSyntax,
enumDecl: EnumDeclSyntax) throws
-> [SwiftSyntax.DeclSyntax] {
var caseNames = [String]()

for member in enumDecl.memberBlock.members {
guard let caseDecl = member.decl.as(EnumCaseDeclSyntax.self) else {
continue
}

for element in caseDecl.elements {
// Validation already happened in ExtensionMacro
caseNames.append(element.name.text)
}
}

// Generate `static var jsonSchema: ...` computed property.
let anyOfList = caseNames.map { "\"\($0)\"" }.joined(separator: ", ")
let generationSchemaCode = """
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [\(anyOfList)])
}
"""

// Generate `var modelOutput: ...` computed property.
let switchCases = caseNames.map { caseName in
"case .\(caseName):\n \"\(caseName)\".modelOutput"
}.joined(separator: "\n ")

let modelOutputCode = """
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
switch self {
\(switchCases)
}
}
"""

return [
DeclSyntax(stringLiteral: generationSchemaCode),
DeclSyntax(stringLiteral: modelOutputCode),
]
}

private static func generateMembers(structDescription: String?,
propertyInfos: [PropertyInfo]) -> [DeclSyntax] {
private static func generateStructMembers(structDescription: String?,
propertyInfos: [PropertyInfo]) -> [DeclSyntax] {
var propertyNames = [String]()
var propertySchemas = [String]()
var partiallyGeneratedProperties = [String]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,54 +48,54 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
let lastName: String
let age: Int

nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(
type: Self.self,
properties: [
FirebaseAILogic.JSONSchema.Property(name: "firstName", type: String.self),
FirebaseAILogic.JSONSchema.Property(name: "middleName", type: String?.self),
FirebaseAILogic.JSONSchema.Property(name: "lastName", type: String.self),
FirebaseAILogic.JSONSchema.Property(name: "age", type: Int.self)
]
)
}
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(
type: Self.self,
properties: [
FirebaseAILogic.JSONSchema.Property(name: "firstName", type: String.self),
FirebaseAILogic.JSONSchema.Property(name: "middleName", type: String?.self),
FirebaseAILogic.JSONSchema.Property(name: "lastName", type: String.self),
FirebaseAILogic.JSONSchema.Property(name: "age", type: Int.self)
]
)
}

nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
var properties = [(name: String, value: any ConvertibleToModelOutput)]()
addProperty(name: "firstName", value: self.firstName)
addProperty(name: "middleName", value: self.middleName)
addProperty(name: "lastName", value: self.lastName)
addProperty(name: "age", value: self.age)
return ModelOutput(
properties: properties,
uniquingKeysWith: { _, second in
second
}
)
func addProperty(name: String, value: some FirebaseGenerable) {
properties.append((name, value))
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
var properties = [(name: String, value: any ConvertibleToModelOutput)]()
addProperty(name: "firstName", value: self.firstName)
addProperty(name: "middleName", value: self.middleName)
addProperty(name: "lastName", value: self.lastName)
addProperty(name: "age", value: self.age)
return ModelOutput(
properties: properties,
uniquingKeysWith: { _, second in
second
}
func addProperty(name: String, value: (some FirebaseGenerable)?) {
if let value {
properties.append((name, value))
}
)
func addProperty(name: String, value: some FirebaseGenerable) {
properties.append((name, value))
}
func addProperty(name: String, value: (some FirebaseGenerable)?) {
if let value {
properties.append((name, value))
}
}
}

nonisolated struct Partial: Identifiable, FirebaseAILogic.ConvertibleFromModelOutput {
var id: FirebaseAILogic.ResponseID
var firstName: String.Partial?
var middleName: String?.Partial?
var lastName: String.Partial?
var age: Int.Partial?
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
self.id = content.id ?? FirebaseAILogic.ResponseID()
self.firstName = try content.value(forProperty: "firstName")
self.middleName = try content.value(forProperty: "middleName")
self.lastName = try content.value(forProperty: "lastName")
self.age = try content.value(forProperty: "age")
}
nonisolated struct Partial: Identifiable, FirebaseAILogic.ConvertibleFromModelOutput {
var id: FirebaseAILogic.ResponseID
var firstName: String.Partial?
var middleName: String?.Partial?
var lastName: String.Partial?
var age: Int.Partial?
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
self.id = content.id ?? FirebaseAILogic.ResponseID()
self.firstName = try content.value(forProperty: "firstName")
self.middleName = try content.value(forProperty: "middleName")
self.lastName = try content.value(forProperty: "lastName")
self.age = try content.value(forProperty: "age")
}
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand All @@ -108,7 +108,70 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
}
}
""",
macros: testMacros
macros: testMacros,
indentationWidth: .spaces(2)
)
#else
throw XCTSkip("Macros are only supported when running tests for the host platform")
#endif
}

func testEnumMacro() throws {
#if canImport(FirebaseAILogicMacros)
assertMacroExpansion(
"""
@FirebaseGenerable
enum Pet {
case cat
case dog
case fish
}
""",
expandedSource: """
enum Pet {
case cat
case dog
case fish

nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: ["cat", "dog", "fish"])
}

nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
switch self {
case .cat:
"cat".modelOutput
case .dog:
"dog".modelOutput
case .fish:
"fish".modelOutput
}
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension Pet: FirebaseAILogic.FirebaseGenerable {
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
let rawValue = try content.value(String.self)
switch rawValue {
case "cat":
self = .cat
case "dog":
self = .dog
case "fish":
self = .fish
default:
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
FirebaseAILogic.GenerativeModel.GenerationError.Context(
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
)
)
}
}
}
""",
macros: testMacros,
indentationWidth: .spaces(2)
)
#else
throw XCTSkip("Macros are only supported when running tests for the host platform")
Expand Down
18 changes: 12 additions & 6 deletions FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,14 @@ struct SchemaTests {
@FirebaseGenerable
struct Task {
let title: String
@FirebaseGuide(description: "The priority level", .anyOf(["low", "medium", "high"]))
let priority: String

@FirebaseGuide(description: "The priority level")
let priority: Priority

@FirebaseGenerable
enum Priority {
case low, medium, high
}
}

@Test(arguments: testConfigs(
Expand All @@ -514,15 +520,15 @@ struct SchemaTests {
generationConfig: SchemaTests.generationConfig(schema: schema),
safetySettings: safetySettings
)
let prompt = "Create a high priority task titled 'Fix Bug'."
let prompt = "Create a medium priority task titled 'Feature Request'."

let response = try await model.generateContent(prompt)

let text = try #require(response.text)
let modelOutput = try ModelOutput(json: text)
let task = try Task(modelOutput)
#expect(task.title == "Fix Bug")
#expect(task.priority == "high")
#expect(task.title == "Feature Request")
#expect(task.priority == .medium)
}

@Test(arguments: InstanceConfig.allConfigs)
Expand All @@ -538,7 +544,7 @@ struct SchemaTests {

let task = response.content
#expect(task.title == "Fix Bug")
#expect(task.priority == "high")
#expect(task.priority == .high)
}

@FirebaseGenerable
Expand Down
Loading