Skip to content

Commit c4bef13

Browse files
authored
Support RETURNING clause for generated columns (#111)
* Support `RETURNING` clause for generated columns * wip * wip * wip * wip * wip
1 parent 53dd31b commit c4bef13

File tree

12 files changed

+408
-29
lines changed

12 files changed

+408
-29
lines changed

Sources/StructuredQueries/Macros.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public macro Table(
4141
public macro Column(
4242
_ name: String = "",
4343
as representableType: (any QueryRepresentable.Type)? = nil,
44-
generated: GeneratedColumn? = nil,
44+
generated: GeneratedColumnStorage? = nil,
4545
primaryKey: Bool = false
4646
) =
4747
#externalMacro(

Sources/StructuredQueriesCore/Never.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ extension Never: Table {
33
public typealias QueryValue = Never
44

55
public static var allColumns: [any TableColumnExpression] { [] }
6+
7+
public static var writableColumns: [any WritableTableColumnExpression] { [] }
68
}
79

810
public static var columns: TableColumns {

Sources/StructuredQueriesCore/Optional.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ extension Optional: Table where Wrapped: Table {
9494
Wrapped.TableColumns.allColumns
9595
}
9696

97+
public static var writableColumns: [any WritableTableColumnExpression] {
98+
Wrapped.TableColumns.writableColumns
99+
}
100+
97101
public subscript<Member>(
98102
dynamicMember keyPath: KeyPath<Wrapped.TableColumns, TableColumn<Wrapped, Member>>
99103
) -> TableColumn<Optional, Member?> {

Sources/StructuredQueriesCore/Statements/Insert.swift

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,10 @@ extension Table {
118118
var valueFragments: [[QueryFragment]] = []
119119
for value in values() {
120120
var valueFragment: [QueryFragment] = []
121-
for column in TableColumns.allColumns {
122-
func open<Root, Value>(_ column: some TableColumnExpression<Root, Value>) -> QueryFragment {
121+
for column in TableColumns.writableColumns {
122+
func open<Root, Value>(
123+
_ column: some WritableTableColumnExpression<Root, Value>
124+
) -> QueryFragment {
123125
Value(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment
124126
}
125127
valueFragment.append(open(column))
@@ -128,7 +130,7 @@ extension Table {
128130
}
129131
return _insert(
130132
or: conflictResolution,
131-
columnNames: TableColumns.allColumns.map(\.name),
133+
columnNames: TableColumns.writableColumns.map(\.name),
132134
values: .values(valueFragments),
133135
onConflict: conflictTargets,
134136
where: targetFilter,
@@ -544,7 +546,8 @@ extension PrimaryKeyedTable {
544546
values: values,
545547
onConflict: { $0.primaryKey },
546548
doUpdate: { updates in
547-
for column in Draft.TableColumns.allColumns where column.name != columns.primaryKey.name {
549+
for column in Draft.TableColumns.writableColumns
550+
where column.name != columns.primaryKey.name {
548551
updates.set(column, #""excluded".\#(quote: column.name)"#)
549552
}
550553
}
@@ -564,8 +567,10 @@ extension PrimaryKeyedTable {
564567
var valueFragments: [[QueryFragment]] = []
565568
for value in values() {
566569
var valueFragment: [QueryFragment] = []
567-
for column in Draft.TableColumns.allColumns {
568-
func open<Root, Value>(_ column: some TableColumnExpression<Root, Value>) -> QueryFragment {
570+
for column in Draft.TableColumns.writableColumns {
571+
func open<Root, Value>(
572+
_ column: some WritableTableColumnExpression<Root, Value>
573+
) -> QueryFragment {
569574
Value(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment
570575
}
571576
valueFragment.append(open(column))
@@ -574,7 +579,7 @@ extension PrimaryKeyedTable {
574579
}
575580
return _insert(
576581
or: conflictResolution,
577-
columnNames: Draft.TableColumns.allColumns.map(\.name),
582+
columnNames: Draft.TableColumns.writableColumns.map(\.name),
578583
values: .values(valueFragments),
579584
onConflict: conflictTargets,
580585
where: targetFilter,

Sources/StructuredQueriesCore/TableAlias.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,19 @@ public struct TableAlias<
125125
#endif
126126
}
127127

128+
public static var writableColumns: [any WritableTableColumnExpression] {
129+
#if compiler(>=6.1)
130+
return Base.TableColumns.writableColumns.map { $0._aliased(Name.self) }
131+
#else
132+
func open(
133+
_ column: some WritableTableColumnExpression
134+
) -> any WritableTableColumnExpression {
135+
column._aliased(Name.self)
136+
}
137+
return Base.TableColumns.writableColumns.map { open($0) }
138+
#endif
139+
}
140+
128141
public typealias QueryValue = TableAlias
129142

130143
public subscript<Member>(
@@ -136,6 +149,16 @@ public struct TableAlias<
136149
keyPath: \.[member: \Member.self, column: column._keyPath]
137150
)
138151
}
152+
153+
public subscript<Member>(
154+
dynamicMember keyPath: KeyPath<Base.TableColumns, GeneratedColumn<Base, Member>>
155+
) -> GeneratedColumn<TableAlias, Member> {
156+
let column = Base.columns[keyPath: keyPath]
157+
return GeneratedColumn<TableAlias, Member>(
158+
column.name,
159+
keyPath: \.[member: \Member.self, column: column._keyPath]
160+
)
161+
}
139162
}
140163
}
141164

Sources/StructuredQueriesCore/TableColumn.swift

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,27 @@ public protocol TableColumnExpression<Root, Value>: QueryExpression where Value
2020
) -> any TableColumnExpression<TableAlias<Root, Name>, Value>
2121
}
2222

23+
/// A type representing a _writable_ table column, _i.e._ not a generated column.
24+
public protocol WritableTableColumnExpression<Root, Value>: TableColumnExpression {
25+
func _aliased<Name: AliasName>(
26+
_ alias: Name.Type
27+
) -> any WritableTableColumnExpression<TableAlias<Root, Name>, Value>
28+
}
29+
30+
extension WritableTableColumnExpression {
31+
public func _aliased<Name: AliasName>(
32+
_ alias: Name.Type
33+
) -> any TableColumnExpression<TableAlias<Root, Name>, Value> {
34+
_aliased(alias)
35+
}
36+
}
37+
2338
/// A type representing a table column.
2439
///
2540
/// Don't create instances of this value directly. Instead, use the `@Table` and `@Column` macros to
2641
/// generate values of this type.
2742
public struct TableColumn<Root: Table, Value: QueryRepresentable & QueryBindable>:
28-
TableColumnExpression,
43+
WritableTableColumnExpression,
2944
Sendable
3045
where Value.QueryOutput: Sendable {
3146
public typealias QueryValue = Value
@@ -70,22 +85,80 @@ where Value.QueryOutput: Sendable {
7085

7186
public func _aliased<Name>(
7287
_ alias: Name.Type
73-
) -> any TableColumnExpression<TableAlias<Root, Name>, Value> {
88+
) -> any WritableTableColumnExpression<TableAlias<Root, Name>, Value> {
7489
TableColumn<TableAlias<Root, Name>, Value>(
7590
name,
7691
keyPath: \.[member: \Value.self, column: _keyPath]
7792
)
7893
}
7994
}
8095

81-
/// A type that describes how a table column is generated (e.g. SQLite generated columns).
96+
/// A type that describes how a table column is generated (_e.g._, SQLite generated columns).
8297
///
8398
/// You provide a value of this type to a `@Column` macro to differentiate between generated columns
8499
/// that are physically stored in the database table and those that are "virtual".
85100
///
86101
/// ```swift
87102
/// @Column(generated: .stored)
88103
/// ```
89-
public enum GeneratedColumn {
104+
public enum GeneratedColumnStorage {
90105
case virtual, stored
91106
}
107+
108+
/// A type representing a generated column.
109+
///
110+
/// Don't create instances of this value directly. Instead, use the `@Table` and `@Column` macros to
111+
/// generate values of this type.
112+
public struct GeneratedColumn<Root: Table, Value: QueryRepresentable & QueryBindable>:
113+
TableColumnExpression,
114+
Sendable
115+
where Value.QueryOutput: Sendable {
116+
public typealias QueryValue = Value
117+
118+
public let name: String
119+
120+
public let defaultValue: Value.QueryOutput?
121+
122+
let _keyPath: KeyPath<Root, Value.QueryOutput> & Sendable
123+
124+
public var keyPath: KeyPath<Root, Value.QueryOutput> {
125+
_keyPath
126+
}
127+
128+
public init(
129+
_ name: String,
130+
keyPath: KeyPath<Root, Value.QueryOutput> & Sendable,
131+
default defaultValue: Value.QueryOutput? = nil
132+
) {
133+
self.name = name
134+
self.defaultValue = defaultValue
135+
self._keyPath = keyPath
136+
}
137+
138+
public init(
139+
_ name: String,
140+
keyPath: KeyPath<Root, Value.QueryOutput> & Sendable,
141+
default defaultValue: Value? = nil
142+
) where Value == Value.QueryOutput {
143+
self.name = name
144+
self.defaultValue = defaultValue
145+
self._keyPath = keyPath
146+
}
147+
148+
public func decode(_ decoder: inout some QueryDecoder) throws -> Value.QueryOutput {
149+
try Value(decoder: &decoder).queryOutput
150+
}
151+
152+
public var queryFragment: QueryFragment {
153+
"\(Root.self).\(quote: name)"
154+
}
155+
156+
public func _aliased<Name>(
157+
_ alias: Name.Type
158+
) -> any TableColumnExpression<TableAlias<Root, Name>, Value> {
159+
TableColumn<TableAlias<Root, Name>, Value>(
160+
name,
161+
keyPath: \.[member: \Value.self, column: _keyPath]
162+
)
163+
}
164+
}

Sources/StructuredQueriesCore/TableDefinition.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
public protocol TableDefinition<QueryValue>: QueryExpression where QueryValue: Table {
77
/// An array of this table's columns.
88
static var allColumns: [any TableColumnExpression] { get }
9+
10+
/// An array of this table's writable (non-generated) columns.
11+
static var writableColumns: [any WritableTableColumnExpression] { get }
912
}
1013

1114
extension TableDefinition {

Sources/StructuredQueriesCore/Updates.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ public struct Updates<Base: Table> {
4040

4141
@_disfavoredOverload
4242
public subscript<Value: QueryExpression>(
43-
dynamicMember keyPath: KeyPath<Base.TableColumns, some TableColumnExpression<Base, Value>>
43+
dynamicMember keyPath: KeyPath<
44+
Base.TableColumns,
45+
some WritableTableColumnExpression<Base, Value>
46+
>
4447
) -> Value.QueryOutput {
4548
@available(*, unavailable)
4649
get { fatalError() }

Sources/StructuredQueriesMacros/TableMacro.swift

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,25 @@ extension TableMacro: ExtensionMacro {
216216
else {
217217
continue
218218
}
219+
guard property.bindingSpecifier.tokenKind == .keyword(.let)
220+
else {
221+
diagnostics.append(
222+
Diagnostic(
223+
node: property.bindingSpecifier,
224+
message: MacroExpansionErrorMessage(
225+
"Generated column property must be declared with a 'let'"
226+
),
227+
fixIt: .replace(
228+
message: MacroExpansionFixItMessage("Replace 'var' with 'let'"),
229+
oldNode: Syntax(property.bindingSpecifier),
230+
newNode: Syntax(
231+
property.bindingSpecifier.with(\.tokenKind, .keyword(.let))
232+
)
233+
)
234+
)
235+
)
236+
continue
237+
}
219238
isGenerated = true
220239

221240
case let argument?:
@@ -256,13 +275,16 @@ extension TableMacro: ExtensionMacro {
256275
if isGenerated {
257276
columnsProperties.append(
258277
"""
259-
public var \(identifier): some \(moduleName).QueryExpression<\(raw: columnQueryValueType?.trimmedDescription ?? "_")> {
260-
\(moduleName).TableColumn<\
261-
QueryValue, \
262-
\(columnQueryValueType?.rewritten(selfRewriter) ?? "_")\
263-
>(\
264-
\(columnName), \
265-
keyPath: \\QueryValue.\(identifier))
278+
public var \(identifier): \(moduleName).GeneratedColumn<\
279+
QueryValue, \
280+
\(raw: columnQueryValueType?.trimmedDescription ?? "_")\
281+
> {
282+
\(moduleName).GeneratedColumn<\
283+
QueryValue, \
284+
\(columnQueryValueType?.rewritten(selfRewriter) ?? "_")\
285+
>(\
286+
\(columnName), \
287+
keyPath: \\QueryValue.\(identifier))
266288
}
267289
"""
268290
)
@@ -603,6 +625,7 @@ extension TableMacro: MemberMacro {
603625
}
604626
let type = IdentifierTypeSyntax(name: declaration.name.trimmed)
605627
var allColumns: [TokenSyntax] = []
628+
var writableColumns: [TokenSyntax] = []
606629
var selectedColumns: [TokenSyntax] = []
607630
var columnsProperties: [DeclSyntax] = []
608631
var decodings: [String] = []
@@ -748,8 +771,11 @@ extension TableMacro: MemberMacro {
748771
if isGenerated {
749772
columnsProperties.append(
750773
"""
751-
public var \(identifier): some \(moduleName).QueryExpression<\(raw: columnQueryValueType?.trimmedDescription ?? "_")> { \
752-
\(moduleName).TableColumn<\
774+
public var \(identifier): \(moduleName).GeneratedColumn<\
775+
QueryValue, \
776+
\(raw: columnQueryValueType?.trimmedDescription ?? "_")\
777+
> { \
778+
\(moduleName).GeneratedColumn<\
753779
QueryValue, \
754780
\(columnQueryValueType?.rewritten(selfRewriter) ?? "_")\
755781
>(\
@@ -759,6 +785,7 @@ extension TableMacro: MemberMacro {
759785
}
760786
"""
761787
)
788+
allColumns.append(identifier)
762789
} else {
763790
columnsProperties.append(
764791
"""
@@ -772,6 +799,7 @@ extension TableMacro: MemberMacro {
772799
"""
773800
)
774801
allColumns.append(identifier)
802+
writableColumns.append(identifier)
775803
}
776804
let decodedType = columnQueryValueType?.asNonOptionalType()
777805
if let defaultValue {
@@ -1007,6 +1035,9 @@ extension TableMacro: MemberMacro {
10071035
public static var allColumns: [any \(moduleName).TableColumnExpression] { \
10081036
[\(allColumns.map { "QueryValue.columns.\($0)" as ExprSyntax }, separator: ", ")]
10091037
}
1038+
public static var writableColumns: [any \(moduleName).WritableTableColumnExpression] { \
1039+
[\(writableColumns.map { "QueryValue.columns.\($0)" as ExprSyntax }, separator: ", ")]
1040+
}
10101041
public var queryFragment: QueryFragment {
10111042
"\(selectedColumns.map { #"\(self.\#($0))"# as ExprSyntax }, separator: ", ")"
10121043
}

0 commit comments

Comments
 (0)