Skip to content

Commit 62f2b87

Browse files
committed
Add support for simple enums (without associated values)
1 parent 022f47d commit 62f2b87

File tree

3 files changed

+238
-58
lines changed

3 files changed

+238
-58
lines changed

FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Sources/FirebaseAILogicMacros/FirebaseGenerableMacro.swift

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,20 @@ public struct FirebaseGenerableMacro: ExtensionMacro {
2626
in _: some SwiftSyntaxMacros
2727
.MacroExpansionContext) throws
2828
-> [SwiftSyntax.ExtensionDeclSyntax] {
29-
guard let structDecl = declaration.as(StructDeclSyntax.self) else {
30-
throw MacroExpansionErrorMessage("`@FirebaseGenerable` can only be applied to a struct.")
29+
if let structDecl = declaration.as(StructDeclSyntax.self) {
30+
return try expansionForStruct(of: structDecl, type: type)
31+
} else if let enumDecl = declaration.as(EnumDeclSyntax.self) {
32+
return try expansionForEnum(of: enumDecl, type: type)
33+
} else {
34+
throw MacroExpansionErrorMessage(
35+
"`@FirebaseGenerable` can only be applied to a struct or enum."
36+
)
3137
}
38+
}
3239

40+
private static func expansionForStruct(of structDecl: StructDeclSyntax,
41+
type: some SwiftSyntax.TypeSyntaxProtocol) throws
42+
-> [SwiftSyntax.ExtensionDeclSyntax] {
3343
var propertyInits = [String]()
3444

3545
for member in structDecl.memberBlock.members {
@@ -73,6 +83,57 @@ public struct FirebaseGenerableMacro: ExtensionMacro {
7383

7484
return declarations
7585
}
86+
87+
private static func expansionForEnum(of enumDecl: EnumDeclSyntax,
88+
type: some SwiftSyntax.TypeSyntaxProtocol) throws
89+
-> [SwiftSyntax.ExtensionDeclSyntax] {
90+
var caseNames = [String]()
91+
92+
for member in enumDecl.memberBlock.members {
93+
guard let caseDecl = member.decl.as(EnumCaseDeclSyntax.self) else {
94+
continue
95+
}
96+
97+
for element in caseDecl.elements {
98+
if element.parameterClause != nil {
99+
throw MacroExpansionErrorMessage(
100+
"`@FirebaseGenerable` does not currently support enums with associated values."
101+
)
102+
}
103+
caseNames.append(element.name.text)
104+
}
105+
}
106+
107+
let caseChecks = caseNames.map { caseName in
108+
"case \"\(caseName)\":\nself = .\(caseName)"
109+
}.joined(separator: "\n")
110+
111+
var declarations = [ExtensionDeclSyntax]()
112+
let declSyntaxString = """
113+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
114+
extension \(type.trimmed): FirebaseAILogic.FirebaseGenerable {
115+
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
116+
let rawValue = try content.value(String.self)
117+
switch rawValue {
118+
\(caseChecks)
119+
default:
120+
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
121+
FirebaseAILogic.GenerativeModel.GenerationError.Context(
122+
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
123+
)
124+
)
125+
}
126+
}
127+
}
128+
"""
129+
let declSyntax = DeclSyntax(stringLiteral: declSyntaxString)
130+
guard let extensionDecl = declSyntax.as(ExtensionDeclSyntax.self) else {
131+
return []
132+
}
133+
declarations.append(extensionDecl)
134+
135+
return declarations
136+
}
76137
}
77138

78139
struct PropertyInfo {
@@ -89,11 +150,20 @@ extension FirebaseGenerableMacro: MemberMacro {
89150
conformingTo _: [SwiftSyntax.TypeSyntax],
90151
in _: some SwiftSyntaxMacros
91152
.MacroExpansionContext) throws -> [SwiftSyntax.DeclSyntax] {
92-
// Ensure the macro is attached to a struct declaration.
93-
guard let structDecl = declaration.as(StructDeclSyntax.self) else {
94-
throw MacroExpansionErrorMessage("`@Generable` can only be applied to a struct.")
153+
if let structDecl = declaration.as(StructDeclSyntax.self) {
154+
return try expansionForStruct(of: node, structDecl: structDecl)
155+
} else if let enumDecl = declaration.as(EnumDeclSyntax.self) {
156+
return try expansionForEnum(of: node, enumDecl: enumDecl)
157+
} else {
158+
throw MacroExpansionErrorMessage(
159+
"`@FirebaseGenerable` can only be applied to a struct or enum."
160+
)
95161
}
162+
}
96163

164+
private static func expansionForStruct(of node: SwiftSyntax.AttributeSyntax,
165+
structDecl: StructDeclSyntax) throws
166+
-> [SwiftSyntax.DeclSyntax] {
97167
// Find the description for the struct itself from the @Generable macro.
98168
let structDescription = try getDescriptionFromGenerableMacro(node)
99169

@@ -149,16 +219,57 @@ extension FirebaseGenerableMacro: MemberMacro {
149219
))
150220
}
151221

152-
let declarations = generateMembers(
222+
return generateStructMembers(
153223
structDescription: structDescription,
154224
propertyInfos: propertyInfos
155225
)
226+
}
156227

157-
return declarations
228+
private static func expansionForEnum(of node: SwiftSyntax.AttributeSyntax,
229+
enumDecl: EnumDeclSyntax) throws
230+
-> [SwiftSyntax.DeclSyntax] {
231+
var caseNames = [String]()
232+
233+
for member in enumDecl.memberBlock.members {
234+
guard let caseDecl = member.decl.as(EnumCaseDeclSyntax.self) else {
235+
continue
236+
}
237+
238+
for element in caseDecl.elements {
239+
// Validation already happened in ExtensionMacro
240+
caseNames.append(element.name.text)
241+
}
242+
}
243+
244+
// Generate `static var jsonSchema: ...` computed property.
245+
let anyOfList = caseNames.map { "\"\($0)\"" }.joined(separator: ", ")
246+
let generationSchemaCode = """
247+
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
248+
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: [\(anyOfList)])
249+
}
250+
"""
251+
252+
// Generate `var modelOutput: ...` computed property.
253+
let switchCases = caseNames.map { caseName in
254+
"case .\(caseName):\n \"\(caseName)\".modelOutput"
255+
}.joined(separator: "\n ")
256+
257+
let modelOutputCode = """
258+
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
259+
switch self {
260+
\(switchCases)
261+
}
262+
}
263+
"""
264+
265+
return [
266+
DeclSyntax(stringLiteral: generationSchemaCode),
267+
DeclSyntax(stringLiteral: modelOutputCode),
268+
]
158269
}
159270

160-
private static func generateMembers(structDescription: String?,
161-
propertyInfos: [PropertyInfo]) -> [DeclSyntax] {
271+
private static func generateStructMembers(structDescription: String?,
272+
propertyInfos: [PropertyInfo]) -> [DeclSyntax] {
162273
var propertyNames = [String]()
163274
var propertySchemas = [String]()
164275
var partiallyGeneratedProperties = [String]()

FirebaseAI/Tests/TestApp/Packages/FirebaseAILogicExtended/Tests/FirebaseGenerableMacroTests/FirebaseGenerableTests.swift

Lines changed: 106 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,54 +48,54 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
4848
let lastName: String
4949
let age: Int
5050
51-
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
52-
FirebaseAILogic.JSONSchema(
53-
type: Self.self,
54-
properties: [
55-
FirebaseAILogic.JSONSchema.Property(name: "firstName", type: String.self),
56-
FirebaseAILogic.JSONSchema.Property(name: "middleName", type: String?.self),
57-
FirebaseAILogic.JSONSchema.Property(name: "lastName", type: String.self),
58-
FirebaseAILogic.JSONSchema.Property(name: "age", type: Int.self)
59-
]
60-
)
61-
}
51+
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
52+
FirebaseAILogic.JSONSchema(
53+
type: Self.self,
54+
properties: [
55+
FirebaseAILogic.JSONSchema.Property(name: "firstName", type: String.self),
56+
FirebaseAILogic.JSONSchema.Property(name: "middleName", type: String?.self),
57+
FirebaseAILogic.JSONSchema.Property(name: "lastName", type: String.self),
58+
FirebaseAILogic.JSONSchema.Property(name: "age", type: Int.self)
59+
]
60+
)
61+
}
6262
63-
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
64-
var properties = [(name: String, value: any ConvertibleToModelOutput)]()
65-
addProperty(name: "firstName", value: self.firstName)
66-
addProperty(name: "middleName", value: self.middleName)
67-
addProperty(name: "lastName", value: self.lastName)
68-
addProperty(name: "age", value: self.age)
69-
return ModelOutput(
70-
properties: properties,
71-
uniquingKeysWith: { _, second in
72-
second
73-
}
74-
)
75-
func addProperty(name: String, value: some FirebaseGenerable) {
76-
properties.append((name, value))
63+
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
64+
var properties = [(name: String, value: any ConvertibleToModelOutput)]()
65+
addProperty(name: "firstName", value: self.firstName)
66+
addProperty(name: "middleName", value: self.middleName)
67+
addProperty(name: "lastName", value: self.lastName)
68+
addProperty(name: "age", value: self.age)
69+
return ModelOutput(
70+
properties: properties,
71+
uniquingKeysWith: { _, second in
72+
second
7773
}
78-
func addProperty(name: String, value: (some FirebaseGenerable)?) {
79-
if let value {
80-
properties.append((name, value))
81-
}
74+
)
75+
func addProperty(name: String, value: some FirebaseGenerable) {
76+
properties.append((name, value))
77+
}
78+
func addProperty(name: String, value: (some FirebaseGenerable)?) {
79+
if let value {
80+
properties.append((name, value))
8281
}
8382
}
83+
}
8484
85-
nonisolated struct Partial: Identifiable, FirebaseAILogic.ConvertibleFromModelOutput {
86-
var id: FirebaseAILogic.ResponseID
87-
var firstName: String.Partial?
88-
var middleName: String?.Partial?
89-
var lastName: String.Partial?
90-
var age: Int.Partial?
91-
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
92-
self.id = content.id ?? FirebaseAILogic.ResponseID()
93-
self.firstName = try content.value(forProperty: "firstName")
94-
self.middleName = try content.value(forProperty: "middleName")
95-
self.lastName = try content.value(forProperty: "lastName")
96-
self.age = try content.value(forProperty: "age")
97-
}
85+
nonisolated struct Partial: Identifiable, FirebaseAILogic.ConvertibleFromModelOutput {
86+
var id: FirebaseAILogic.ResponseID
87+
var firstName: String.Partial?
88+
var middleName: String?.Partial?
89+
var lastName: String.Partial?
90+
var age: Int.Partial?
91+
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
92+
self.id = content.id ?? FirebaseAILogic.ResponseID()
93+
self.firstName = try content.value(forProperty: "firstName")
94+
self.middleName = try content.value(forProperty: "middleName")
95+
self.lastName = try content.value(forProperty: "lastName")
96+
self.age = try content.value(forProperty: "age")
9897
}
98+
}
9999
}
100100
101101
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
@@ -108,7 +108,70 @@ final class FirebaseAILogicMacrosTests: XCTestCase {
108108
}
109109
}
110110
""",
111-
macros: testMacros
111+
macros: testMacros,
112+
indentationWidth: .spaces(2)
113+
)
114+
#else
115+
throw XCTSkip("Macros are only supported when running tests for the host platform")
116+
#endif
117+
}
118+
119+
func testEnumMacro() throws {
120+
#if canImport(FirebaseAILogicMacros)
121+
assertMacroExpansion(
122+
"""
123+
@FirebaseGenerable
124+
enum Pet {
125+
case cat
126+
case dog
127+
case fish
128+
}
129+
""",
130+
expandedSource: """
131+
enum Pet {
132+
case cat
133+
case dog
134+
case fish
135+
136+
nonisolated static var jsonSchema: FirebaseAILogic.JSONSchema {
137+
FirebaseAILogic.JSONSchema(type: Self.self, anyOf: ["cat", "dog", "fish"])
138+
}
139+
140+
nonisolated var modelOutput: FirebaseAILogic.ModelOutput {
141+
switch self {
142+
case .cat:
143+
"cat".modelOutput
144+
case .dog:
145+
"dog".modelOutput
146+
case .fish:
147+
"fish".modelOutput
148+
}
149+
}
150+
}
151+
152+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
153+
extension Pet: FirebaseAILogic.FirebaseGenerable {
154+
nonisolated init(_ content: FirebaseAILogic.ModelOutput) throws {
155+
let rawValue = try content.value(String.self)
156+
switch rawValue {
157+
case "cat":
158+
self = .cat
159+
case "dog":
160+
self = .dog
161+
case "fish":
162+
self = .fish
163+
default:
164+
throw FirebaseAILogic.GenerativeModel.GenerationError.decodingFailure(
165+
FirebaseAILogic.GenerativeModel.GenerationError.Context(
166+
debugDescription: "Unexpected value \\"\\(rawValue)\\" for \\(Self.self)"
167+
)
168+
)
169+
}
170+
}
171+
}
172+
""",
173+
macros: testMacros,
174+
indentationWidth: .spaces(2)
112175
)
113176
#else
114177
throw XCTSkip("Macros are only supported when running tests for the host platform")

FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,14 @@ struct SchemaTests {
490490
@FirebaseGenerable
491491
struct Task {
492492
let title: String
493-
@FirebaseGuide(description: "The priority level", .anyOf(["low", "medium", "high"]))
494-
let priority: String
493+
494+
@FirebaseGuide(description: "The priority level")
495+
let priority: Priority
496+
497+
@FirebaseGenerable
498+
enum Priority {
499+
case low, medium, high
500+
}
495501
}
496502

497503
@Test(arguments: testConfigs(
@@ -514,15 +520,15 @@ struct SchemaTests {
514520
generationConfig: SchemaTests.generationConfig(schema: schema),
515521
safetySettings: safetySettings
516522
)
517-
let prompt = "Create a high priority task titled 'Fix Bug'."
523+
let prompt = "Create a medium priority task titled 'Feature Request'."
518524

519525
let response = try await model.generateContent(prompt)
520526

521527
let text = try #require(response.text)
522528
let modelOutput = try ModelOutput(json: text)
523529
let task = try Task(modelOutput)
524-
#expect(task.title == "Fix Bug")
525-
#expect(task.priority == "high")
530+
#expect(task.title == "Feature Request")
531+
#expect(task.priority == .medium)
526532
}
527533

528534
@Test(arguments: InstanceConfig.allConfigs)
@@ -538,7 +544,7 @@ struct SchemaTests {
538544

539545
let task = response.content
540546
#expect(task.title == "Fix Bug")
541-
#expect(task.priority == "high")
547+
#expect(task.priority == .high)
542548
}
543549

544550
@FirebaseGenerable

0 commit comments

Comments
 (0)