Skip to content

Commit 23b4bb8

Browse files
committed
fix: CPK uni-directional has-many lazy list load (#2730)
* fix: uni-directional has-many lazy list load * remove redundant log * rename targetNames to use associatedWithFields * rename associatedWithFields to associatedFields
1 parent a825990 commit 23b4bb8

File tree

45 files changed

+288
-249
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+288
-249
lines changed

Amplify/Categories/DataStore/Model/Internal/ModelListProvider.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ public enum ModelListProviderState<Element: Model> {
2525
/// If the list represents an association between two models, the `associatedIdentifiers` will
2626
/// hold the information necessary to query the associated elements (e.g. comments of a post)
2727
///
28-
/// The associatedField represents the field to which the owner of the `List` is linked to.
28+
/// The associatedFields represents the field to which the owner of the `List` is linked to.
2929
/// For example, if `Post.comments` is associated with `Comment.post` the `List<Comment>`
3030
/// of `Post` will have a reference to the `post` field in `Comment`.
31-
case notLoaded(associatedIdentifiers: [String], associatedField: String)
31+
case notLoaded(associatedIdentifiers: [String], associatedFields: [String])
3232
case loaded([Element])
3333
}
3434

Amplify/Categories/DataStore/Model/Internal/Schema/ModelField+Association.swift

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ import Foundation
8787
/// - Warning: Although this has `public` access, it is intended for internal & codegen use and should not be used
8888
/// directly by host applications. The behavior of this may change without warning.
8989
public enum ModelAssociation {
90-
case hasMany(associatedFieldName: String?)
90+
case hasMany(associatedFieldName: String?, associatedFieldNames: [String] = [])
9191
case hasOne(associatedFieldName: String?, targetNames: [String])
9292
case belongsTo(associatedFieldName: String?, targetNames: [String])
9393

@@ -97,9 +97,9 @@ public enum ModelAssociation {
9797
let targetNames = targetName.map { [$0] } ?? []
9898
return .belongsTo(associatedFieldName: nil, targetNames: targetNames)
9999
}
100-
101-
public static func hasMany(associatedWith: CodingKey?) -> ModelAssociation {
102-
return .hasMany(associatedFieldName: associatedWith?.stringValue)
100+
101+
public static func hasMany(associatedWith: CodingKey? = nil, associatedFields: [CodingKey] = []) -> ModelAssociation {
102+
return .hasMany(associatedFieldName: associatedWith?.stringValue, associatedFieldNames: associatedFields.map { $0.stringValue })
103103
}
104104

105105
@available(*, deprecated, message: "Use hasOne(associatedWith:targetNames:)")
@@ -247,13 +247,9 @@ extension ModelField {
247247
if hasAssociation {
248248
let associatedModel = requiredAssociatedModelName
249249
switch association {
250-
case .belongsTo(let associatedKey, _):
251-
// TODO handle modelName casing (convert to camelCase)
252-
let key = associatedKey ?? associatedModel
253-
let schema = ModelRegistry.modelSchema(from: associatedModel)
254-
return schema?.field(withName: key)
255-
case .hasOne(let associatedKey, _),
256-
.hasMany(let associatedKey):
250+
case .belongsTo(let associatedKey, _),
251+
.hasOne(let associatedKey, _),
252+
.hasMany(let associatedKey, _):
257253
// TODO handle modelName casing (convert to camelCase)
258254
let key = associatedKey ?? associatedModel
259255
let schema = ModelRegistry.modelSchema(from: associatedModel)
@@ -265,6 +261,25 @@ extension ModelField {
265261
return nil
266262
}
267263

264+
/// - Warning: Although this has `public` access, it is intended for internal & codegen use and should not be used
265+
/// directly by host applications. The behavior of this may change without warning. Though it is not used by host
266+
/// application making any change to these `public` types should be backward compatible, otherwise it will be a
267+
/// breaking change.
268+
public var associatedFieldNames: [String] {
269+
switch association {
270+
case .belongsTo(let associatedKey, let associatedKeys),
271+
.hasOne(let associatedKey, let associatedKeys),
272+
.hasMany(let associatedKey, let associatedKeys):
273+
if associatedKeys.isEmpty, let associatedKey = associatedKey {
274+
return [associatedKey]
275+
}
276+
277+
return associatedKeys
278+
case .none:
279+
return []
280+
}
281+
}
282+
268283
/// - Warning: Although this has `public` access, it is intended for internal & codegen use and should not be used
269284
/// directly by host applications. The behavior of this may change without warning. Though it is not used by host
270285
/// application making any change to these `public` types should be backward compatible, otherwise it will be a

Amplify/Categories/DataStore/Model/Internal/Schema/ModelSchema+Definition.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,18 @@ public enum ModelFieldDefinition {
240240
ofType: .collection(of: type),
241241
association: .hasMany(associatedWith: associatedKey))
242242
}
243+
244+
public static func hasMany(_ key: CodingKey,
245+
is nullability: ModelFieldNullability = .required,
246+
isReadOnly: Bool = false,
247+
ofType type: Model.Type,
248+
associatedFields associatedKeys: [CodingKey]) -> ModelFieldDefinition {
249+
return .field(key,
250+
is: nullability,
251+
isReadOnly: isReadOnly,
252+
ofType: .collection(of: type),
253+
association: .hasMany(associatedWith: associatedKeys.first ?? nil, associatedFields: associatedKeys))
254+
}
243255

244256
public static func hasOne(_ key: CodingKey,
245257
is nullability: ModelFieldNullability = .required,

AmplifyPlugins/API/Sources/AWSAPIPlugin/Core/AppSyncListDecoder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public struct AppSyncListDecoder: ModelListDecoder {
1616
/// Metadata that contains information about an associated parent object.
1717
struct Metadata: Codable {
1818
let appSyncAssociatedIdentifiers: [String]
19-
let appSyncAssociatedField: String
19+
let appSyncAssociatedFields: [String]
2020
let apiName: String?
2121
}
2222

AmplifyPlugins/API/Sources/AWSAPIPlugin/Core/AppSyncListProvider.swift

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
2222
/// If the list represents an association between two models, the `associatedIdentifiers` will
2323
/// hold the information necessary to query the associated elements (e.g. comments of a post)
2424
///
25-
/// The associatedField represents the field to which the owner of the `List` is linked to.
25+
/// The associatedFields represents the field to which the owner of the `List` is linked to.
2626
/// For example, if `Post.comments` is associated with `Comment.post` the `List<Comment>`
2727
/// of `Post` will have a reference to the `post` field in `Comment`.
2828
case notLoaded(associatedIdentifiers: [String],
29-
associatedField: String)
29+
associatedFields: [String])
3030

3131
/// If the list is retrieved directly, this state holds the underlying data, nextToken used to create
3232
/// the subsequent GraphQL request, and previous filter used to create the loaded list
@@ -58,7 +58,7 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
5858

5959
convenience init(metadata: AppSyncListDecoder.Metadata) {
6060
self.init(associatedIdentifiers: metadata.appSyncAssociatedIdentifiers,
61-
associatedField: metadata.appSyncAssociatedField,
61+
associatedFields: metadata.appSyncAssociatedFields,
6262
apiName: metadata.apiName)
6363
}
6464

@@ -76,18 +76,18 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
7676
}
7777

7878
// Internal initializer for testing
79-
init(associatedIdentifiers: [String], associatedField: String, apiName: String? = nil) {
79+
init(associatedIdentifiers: [String], associatedFields: [String], apiName: String? = nil) {
8080
self.loadedState = .notLoaded(associatedIdentifiers: associatedIdentifiers,
81-
associatedField: associatedField)
81+
associatedFields: associatedFields)
8282
self.apiName = apiName
8383
}
8484

8585
// MARK: APIs
8686

8787
public func getState() -> ModelListProviderState<Element> {
8888
switch loadedState {
89-
case .notLoaded(let associatedIdentifiers, let associatedField):
90-
return .notLoaded(associatedIdentifiers: associatedIdentifiers, associatedField: associatedField)
89+
case .notLoaded(let associatedIdentifiers, let associatedFields):
90+
return .notLoaded(associatedIdentifiers: associatedIdentifiers, associatedFields: associatedFields)
9191
case .loaded(let elements, _, _):
9292
return .loaded(elements)
9393
}
@@ -97,22 +97,24 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
9797
switch loadedState {
9898
case .loaded(let elements, _, _):
9999
return elements
100-
case .notLoaded(let associatedIdentifiers, let associatedField):
101-
return try await load(associatedIdentifiers: associatedIdentifiers, associatedField: associatedField)
100+
case .notLoaded(let associatedIdentifiers, let associatedFields):
101+
return try await load(associatedIdentifiers: associatedIdentifiers, associatedFields: associatedFields)
102102
}
103103
}
104104

105105
//// Internal `load` to perform the retrieval of the first page and storing it in memory
106106
func load(associatedIdentifiers: [String],
107-
associatedField: String) async throws -> [Element] {
107+
associatedFields: [String]) async throws -> [Element] {
108108
let filter: GraphQLFilter
109-
if associatedIdentifiers.count == 1, let associatedId = associatedIdentifiers.first {
109+
if associatedIdentifiers.count == 1,
110+
let associatedId = associatedIdentifiers.first,
111+
let associatedField = associatedFields.first {
110112
let predicate: QueryPredicate = field(associatedField) == associatedId
111113
filter = predicate.graphQLFilter(for: Element.schema)
112114
} else {
113115
var queryPredicates: [QueryPredicateOperation] = []
114-
let columnNames = columnNames(field: associatedField, Element.schema)
115116

117+
let columnNames = columnNames(fields: associatedFields, Element.schema)
116118
let predicateValues = zip(columnNames, associatedIdentifiers)
117119
for (identifierName, identifierValue) in predicateValues {
118120
queryPredicates.append(QueryPredicateOperation(field: identifierName,
@@ -220,10 +222,10 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
220222

221223
public func encode(to encoder: Encoder) throws {
222224
switch loadedState {
223-
case .notLoaded(let associatedIdentifiers, let associatedField):
225+
case .notLoaded(let associatedIdentifiers, let associatedFields):
224226
let metadata = AppSyncListDecoder.Metadata.init(
225227
appSyncAssociatedIdentifiers: associatedIdentifiers,
226-
appSyncAssociatedField: associatedField,
228+
appSyncAssociatedFields: associatedFields,
227229
apiName: apiName)
228230
var container = encoder.singleValueContainer()
229231
try container.encode(metadata)
@@ -241,9 +243,14 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
241243
// MARK: - Helpers
242244

243245
/// Retrieve the column names for the specified field `field` for this schema.
244-
func columnNames(field: String, _ modelSchema: ModelSchema) -> [String] {
245-
guard let modelField = modelSchema.field(withName: field) else {
246-
return [field]
246+
func columnNames(fields: [String], _ modelSchema: ModelSchema) -> [String] {
247+
// Associated field names have already been resolved from the parent model's has-many targetNames
248+
if fields.count > 1 {
249+
return fields
250+
}
251+
// Resolve the ModelField of the field reference
252+
guard let field = fields.first, let modelField = modelSchema.field(withName: field) else {
253+
return fields
247254
}
248255
let defaultFieldName = modelSchema.name.camelCased() + field.pascalCased() + "Id"
249256
switch modelField.association {
@@ -254,10 +261,10 @@ public class AppSyncListProvider<Element: Model>: ModelListProvider {
254261
}
255262
return targetNames
256263
default:
257-
return [field]
264+
return fields
258265
}
259266
}
260267

261268
}
262269

263-
extension AppSyncListProvider: DefaultLogger { }
270+
extension AppSyncListProvider: DefaultLogger { }

AmplifyPlugins/API/Sources/AWSAPIPlugin/Core/AppSyncModelMetadata.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,9 @@ public struct AppSyncModelMetadataUtils {
113113
// Scenario: Has-many items array is missing.
114114
// Store the association data (parent's identifier and field name)
115115
// This allows the list to perform lazy loading of child items using parent identifier as the predicate
116-
if let associatedField = modelField.associatedField,
117-
modelJSON[modelField.name] == nil {
116+
if modelJSON[modelField.name] == nil {
118117
let appSyncModelMetadata = AppSyncListDecoder.Metadata(appSyncAssociatedIdentifiers: identifiers,
119-
appSyncAssociatedField: associatedField.name,
118+
appSyncAssociatedFields: modelField.associatedFieldNames,
120119
apiName: apiName)
121120
if let serializedMetadata = try? encoder.encode(appSyncModelMetadata),
122121
let metadataJSON = try? decoder.decode(JSONValue.self, from: serializedMetadata) {

AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginLazyLoadTests/GraphQLLazyLoadBaseTest.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,18 @@ class GraphQLLazyLoadBaseTest: XCTestCase {
133133
}
134134

135135
enum AssertListState {
136-
case isNotLoaded(associatedIdentifiers: [String], associatedField: String)
136+
case isNotLoaded(associatedIdentifiers: [String], associatedFields: [String])
137137
case isLoaded(count: Int)
138138
}
139139

140140
func assertList<M: Model>(_ list: List<M>, state: AssertListState) {
141141
switch state {
142-
case .isNotLoaded(let expectedAssociatedIdentifiers, let expectedAssociatedField):
143-
if case .notLoaded(let associatedIdentifiers, let associatedField) = list.listProvider.getState() {
142+
case .isNotLoaded(let expectedAssociatedIdentifiers, let expectedAssociatedFields):
143+
if case .notLoaded(let associatedIdentifiers, let associatedFields) = list.listProvider.getState() {
144144
XCTAssertEqual(associatedIdentifiers, expectedAssociatedIdentifiers)
145-
XCTAssertEqual(associatedField, expectedAssociatedField)
145+
XCTAssertEqual(associatedFields, expectedAssociatedFields)
146146
} else {
147-
XCTFail("It should be not loaded with expected associatedId \(expectedAssociatedIdentifiers) associatedField \(expectedAssociatedField)")
147+
XCTFail("It should be not loaded with expected associatedIds \(expectedAssociatedIdentifiers) associatedFields \(expectedAssociatedFields)")
148148
}
149149
case .isLoaded(let count):
150150
if case .loaded(let loadedList) = list.listProvider.getState() {

AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginLazyLoadTests/LL1/GraphQLLazyLoadPostComment4V2Tests.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ final class GraphQLLazyLoadPostComment4V2Tests: GraphQLLazyLoadBaseTest {
3939

4040
// The loaded post should have comments that are also not loaded
4141
let comments = loadedPost.comments!
42-
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [createdPost.id], associatedField: "post"))
42+
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [createdPost.id], associatedFields: ["post"]))
4343
// load the comments
4444
try await comments.fetch()
4545
assertList(comments, state: .isLoaded(count: 1))
@@ -60,7 +60,7 @@ final class GraphQLLazyLoadPostComment4V2Tests: GraphQLLazyLoadBaseTest {
6060
XCTAssertEqual(loadedPost.id, post.id)
6161
// The loaded post should have comments that are not loaded
6262
let comments = loadedPost.comments!
63-
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedField: "post"))
63+
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedFields: ["post"]))
6464
// load the comments
6565
try await comments.fetch()
6666
assertList(comments, state: .isLoaded(count: 1))
@@ -175,7 +175,7 @@ final class GraphQLLazyLoadPostComment4V2Tests: GraphQLLazyLoadBaseTest {
175175
_ = try await mutate(.create(comment))
176176
let queriedPost = try await query(.get(Post.self, byId: post.id))!
177177
let comments = queriedPost.comments!
178-
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedField: "post"))
178+
assertList(comments, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedFields: ["post"]))
179179
try await comments.fetch()
180180
assertList(comments, state: .isLoaded(count: 1))
181181
assertLazyReference(comments.first!._post, state: .notLoaded(identifiers: [.init(name: "id", value: post.id)]))
@@ -217,7 +217,7 @@ final class GraphQLLazyLoadPostComment4V2Tests: GraphQLLazyLoadBaseTest {
217217
let queriedPosts = try await listQuery(.list(Post.self, where: Post.keys.id == post.id))
218218
assertList(queriedPosts, state: .isLoaded(count: 1))
219219
assertList(queriedPosts.first!.comments!,
220-
state: .isNotLoaded(associatedIdentifiers: [post.id], associatedField: "post"))
220+
state: .isNotLoaded(associatedIdentifiers: [post.id], associatedFields: ["post"]))
221221

222222
let queriedComments = try await listQuery(.list(Comment.self, where: Comment.keys.id == comment.id))
223223
assertList(queriedComments, state: .isLoaded(count: 1))
@@ -394,7 +394,7 @@ final class GraphQLLazyLoadPostComment4V2Tests: GraphQLLazyLoadBaseTest {
394394
switch result {
395395
case .success(let createdPost):
396396
log.verbose("Successfully got createdPost from subscription: \(createdPost)")
397-
assertList(createdPost.comments!, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedField: "post"))
397+
assertList(createdPost.comments!, state: .isNotLoaded(associatedIdentifiers: [post.id], associatedFields: ["post"]))
398398
await onCreatedPost.fulfill()
399399
case .failure(let error):
400400
XCTFail("Got failed result with \(error.errorDescription)")

0 commit comments

Comments
 (0)