diff --git a/packages/datastore/__tests__/subscription-variables-edge-cases.test.ts b/packages/datastore/__tests__/subscription-variables-edge-cases.test.ts new file mode 100644 index 00000000000..31f129ace20 --- /dev/null +++ b/packages/datastore/__tests__/subscription-variables-edge-cases.test.ts @@ -0,0 +1,239 @@ +import { SubscriptionProcessor } from '../src/sync/processors/subscription'; +import { TransformerMutationType, processSubscriptionVariables } from '../src/sync/utils'; +import { SchemaModel, InternalSchema } from '../src/types'; + +describe('Subscription Variables - Edge Cases & Safety', () => { + let mockGraphQL: jest.Mock; + + beforeEach(() => { + mockGraphQL = jest.fn(); + jest.clearAllMocks(); + }); + + const createTestSchema = (): InternalSchema => ({ + namespaces: { + user: { + name: 'user', + models: { + Todo: { + name: 'Todo', + pluralName: 'Todos', + syncable: true, + attributes: [], + fields: { + id: { + name: 'id', + type: 'ID', + isRequired: true, + isArray: false, + }, + }, + }, + }, + relationships: {}, + enums: {}, + nonModels: {}, + }, + }, + version: '1', + codegenVersion: '3.0.0', + }); + + describe('Mutation Protection', () => { + it('should not allow mutations to affect cached values', () => { + const schema = createTestSchema(); + const sharedObject = { storeId: 'initial' }; + + const cache = new WeakMap(); + const result1 = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + sharedObject, + cache, + ); + + sharedObject.storeId = 'mutated'; + + const result2 = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + sharedObject, + cache, + ); + + expect(result1).toEqual(result2); + expect(result2?.storeId).not.toBe('mutated'); + }); + + it('should handle circular references gracefully', () => { + const schema = createTestSchema(); + const circular: any = { storeId: 'test' }; + circular.self = circular; + + const cache = new WeakMap(); + const result = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + circular, + cache, + ); + + expect(result).toBeDefined(); + expect(result?.storeId).toBe('test'); + }); + }); + + describe('Invalid Input Handling', () => { + it('should reject non-object static variables', () => { + const schema = createTestSchema(); + + const testCases = [ + { value: 'string', desc: 'string' }, + { value: 123, desc: 'number' }, + { value: true, desc: 'boolean' }, + { value: ['array'], desc: 'array' }, + ]; + + testCases.forEach(({ value, desc }) => { + const cache = new WeakMap(); + const result = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + value as any, + cache, + ); + + expect(result).toBeUndefined(); + }); + }); + + it('should handle Object.create(null) objects', () => { + const schema = createTestSchema(); + const nullProtoObj = Object.create(null); + nullProtoObj.storeId = 'test'; + + const cache = new WeakMap(); + const result = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + nullProtoObj, + cache, + ); + + expect(result).toBeDefined(); + expect(result?.storeId).toBe('test'); + }); + + it('should handle function that throws', () => { + const schema = createTestSchema(); + + const cache = new WeakMap(); + const mockFn = () => { + throw new Error('Function error'); + }; + const result = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + mockFn, + cache, + ); + + expect(result).toBeUndefined(); + }); + + it('should handle function returning non-object', () => { + const schema = createTestSchema(); + + const testCases = [ + { value: null, desc: 'null' }, + { value: undefined, desc: 'undefined' }, + { value: 'string', desc: 'string' }, + { value: 123, desc: 'number' }, + { value: ['array'], desc: 'array' }, + ]; + + testCases.forEach(({ value, desc }) => { + const cache = new WeakMap(); + const mockFn = () => value; + const result = processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + mockFn as any, + cache, + ); + + expect(result).toBeUndefined(); + }); + }); + }); + + describe('Cache Behavior', () => { + it('should only call function once per operation', () => { + const schema = createTestSchema(); + const mockFn = jest.fn(() => ({ storeId: 'test' })); + + const cache = new WeakMap(); + for (let i = 0; i < 5; i++) { + processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + mockFn, + cache, + ); + } + + expect(mockFn).toHaveBeenCalledTimes(1); + expect(mockFn).toHaveBeenCalledWith(TransformerMutationType.CREATE); + + processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.UPDATE, + mockFn, + cache, + ); + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(mockFn).toHaveBeenCalledWith(TransformerMutationType.UPDATE); + }); + + it('should clear cache on stop', async () => { + const schema = createTestSchema(); + const mockFn = jest.fn(() => ({ storeId: 'test' })); + + const processor = new SubscriptionProcessor( + schema, + new WeakMap(), + {}, + 'DEFAULT' as any, + jest.fn(), + { InternalAPI: { graphql: mockGraphQL } } as any, + { + subscriptionVariables: { + Todo: mockFn, + }, + }, + ); + + let cache = new WeakMap(); + processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + mockFn, + cache, + ); + expect(mockFn).toHaveBeenCalledTimes(1); + + await processor.stop(); + cache = new WeakMap(); + + processSubscriptionVariables( + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + mockFn, + cache, + ); + + expect(mockFn).toHaveBeenCalledTimes(2); + }); + }); +}); \ No newline at end of file diff --git a/packages/datastore/__tests__/subscription-variables.test.ts b/packages/datastore/__tests__/subscription-variables.test.ts new file mode 100644 index 00000000000..64ce81a50cc --- /dev/null +++ b/packages/datastore/__tests__/subscription-variables.test.ts @@ -0,0 +1,298 @@ +import { Observable } from 'rxjs'; +import { SubscriptionProcessor } from '../src/sync/processors/subscription'; +import { TransformerMutationType } from '../src/sync/utils'; +import { SchemaModel, InternalSchema } from '../src/types'; +import { buildSubscriptionGraphQLOperation } from '../src/sync/utils'; + +describe('DataStore Subscription Variables', () => { + let mockObservable: Observable; + let mockGraphQL: jest.Mock; + + beforeEach(() => { + mockObservable = new Observable(() => {}); + mockGraphQL = jest.fn(() => mockObservable); + }); + + describe('buildSubscriptionGraphQLOperation', () => { + it('should include custom variables in subscription query', () => { + const namespace: any = { + name: 'user', + models: {}, + relationships: {}, + enums: {}, + nonModels: {}, + }; + + const modelDefinition: SchemaModel = { + name: 'Todo', + pluralName: 'Todos', + syncable: true, + attributes: [], + fields: { + id: { + name: 'id', + type: 'ID', + isRequired: true, + isArray: false, + }, + title: { + name: 'title', + type: 'String', + isRequired: false, + isArray: false, + }, + storeId: { + name: 'storeId', + type: 'String', + isRequired: false, + isArray: false, + }, + }, + }; + + const customVariables = { + storeId: 'store123', + tenantId: 'tenant456', + }; + + const [opType, opName, query] = buildSubscriptionGraphQLOperation( + namespace, + modelDefinition, + TransformerMutationType.CREATE, + false, + '', + false, + customVariables, + ); + + expect(opType).toBe(TransformerMutationType.CREATE); + expect(opName).toBe('onCreateTodo'); + expect(query).toContain('$storeId: String'); + expect(query).toContain('$tenantId: String'); + expect(query).toContain('storeId: $storeId'); + expect(query).toContain('tenantId: $tenantId'); + }); + + it('should work without custom variables', () => { + const namespace: any = { + name: 'user', + models: {}, + relationships: {}, + enums: {}, + nonModels: {}, + }; + + const modelDefinition: SchemaModel = { + name: 'Todo', + pluralName: 'Todos', + syncable: true, + attributes: [], + fields: { + id: { + name: 'id', + type: 'ID', + isRequired: true, + isArray: false, + }, + title: { + name: 'title', + type: 'String', + isRequired: false, + isArray: false, + }, + }, + }; + + const [opType, opName, query] = buildSubscriptionGraphQLOperation( + namespace, + modelDefinition, + TransformerMutationType.CREATE, + false, + '', + false, + ); + + expect(opType).toBe(TransformerMutationType.CREATE); + expect(opName).toBe('onCreateTodo'); + expect(query).not.toContain('$storeId'); + expect(query).not.toContain('$tenantId'); + }); + }); + + describe('SubscriptionProcessor with custom variables', () => { + it('should use custom variables from config when building subscriptions', () => { + const schema: InternalSchema = { + namespaces: { + user: { + name: 'user', + models: { + Todo: { + name: 'Todo', + pluralName: 'Todos', + syncable: true, + attributes: [], + fields: { + id: { + name: 'id', + type: 'ID', + isRequired: true, + isArray: false, + }, + title: { + name: 'title', + type: 'String', + isRequired: false, + isArray: false, + }, + storeId: { + name: 'storeId', + type: 'String', + isRequired: false, + isArray: false, + }, + }, + }, + }, + relationships: {}, + enums: {}, + nonModels: {}, + }, + }, + version: '1', + codegenVersion: '3.0.0', + }; + + const syncPredicates = new WeakMap(); + const datastoreConfig = { + subscriptionVariables: { + Todo: { + storeId: 'store123', + }, + }, + }; + + const processor = new SubscriptionProcessor( + schema, + syncPredicates, + {}, + 'DEFAULT' as any, + jest.fn(), + { InternalAPI: { graphql: mockGraphQL } } as any, + datastoreConfig, + ); + + // @ts-ignore - accessing private method for testing + const result = processor.buildSubscription( + schema.namespaces.user, + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + 0, + undefined, + 'userPool', + false, + ); + + expect(result.opName).toBe('onCreateTodo'); + expect(result.query).toContain('$storeId: String'); + expect(result.query).toContain('storeId: $storeId'); + }); + + it('should support function-based subscription variables', () => { + const schema: InternalSchema = { + namespaces: { + user: { + name: 'user', + models: { + Todo: { + name: 'Todo', + pluralName: 'Todos', + syncable: true, + attributes: [], + fields: { + id: { + name: 'id', + type: 'ID', + isRequired: true, + isArray: false, + }, + title: { + name: 'title', + type: 'String', + isRequired: false, + isArray: false, + }, + storeId: { + name: 'storeId', + type: 'String', + isRequired: false, + isArray: false, + }, + }, + }, + }, + relationships: {}, + enums: {}, + nonModels: {}, + }, + }, + version: '1', + codegenVersion: '3.0.0', + }; + + const syncPredicates = new WeakMap(); + const datastoreConfig = { + subscriptionVariables: { + Todo: (operation: string) => { + if (operation === TransformerMutationType.CREATE) { + return { storeId: 'store-create' }; + } + if (operation === TransformerMutationType.UPDATE) { + return { storeId: 'store-update' }; + } + return { storeId: 'store-delete' }; + }, + }, + }; + + const processor = new SubscriptionProcessor( + schema, + syncPredicates, + {}, + 'DEFAULT' as any, + jest.fn(), + { InternalAPI: { graphql: mockGraphQL } } as any, + datastoreConfig, + ); + + // Test CREATE operation + // @ts-ignore - accessing private method for testing + const createResult = processor.buildSubscription( + schema.namespaces.user, + schema.namespaces.user.models.Todo, + TransformerMutationType.CREATE, + 0, + undefined, + 'userPool', + false, + ); + + expect(createResult.query).toContain('$storeId: String'); + expect(createResult.query).toContain('storeId: $storeId'); + + // Test UPDATE operation + // @ts-ignore - accessing private method for testing + const updateResult = processor.buildSubscription( + schema.namespaces.user, + schema.namespaces.user.models.Todo, + TransformerMutationType.UPDATE, + 0, + undefined, + 'userPool', + false, + ); + + expect(updateResult.query).toContain('$storeId: String'); + expect(updateResult.query).toContain('storeId: $storeId'); + }); + }); +}); \ No newline at end of file diff --git a/packages/datastore/src/datastore/datastore.ts b/packages/datastore/src/datastore/datastore.ts index b2ba15eb6b5..c45cf94fbea 100644 --- a/packages/datastore/src/datastore/datastore.ts +++ b/packages/datastore/src/datastore/datastore.ts @@ -1406,6 +1406,7 @@ class DataStore { // sync engine processors, storage engine, adapters, etc.. private amplifyConfig: Record = {}; + private datastoreConfig: DataStoreConfig = {}; private authModeStrategy!: AuthModeStrategy; private conflictHandler!: ConflictHandler; private errorHandler!: (error: SyncError) => void; @@ -1566,6 +1567,7 @@ class DataStore { this.authModeStrategy, this.amplifyContext, this.connectivityMonitor, + this.datastoreConfig, ); const fullSyncIntervalInMilliseconds = @@ -2458,6 +2460,7 @@ class DataStore { configure = (config: DataStoreConfig = {}) => { this.amplifyContext.InternalAPI = this.InternalAPI; + this.datastoreConfig = config; const { DataStore: configDataStore, diff --git a/packages/datastore/src/sync/index.ts b/packages/datastore/src/sync/index.ts index 3575caab2a8..992c25dc969 100644 --- a/packages/datastore/src/sync/index.ts +++ b/packages/datastore/src/sync/index.ts @@ -154,6 +154,7 @@ export class SyncEngine { private readonly authModeStrategy: AuthModeStrategy, private readonly amplifyContext: AmplifyContext, private readonly connectivityMonitor?: DataStoreConnectivity, + private readonly datastoreConfig?: Record, ) { this.runningProcesses = new BackgroundProcessManager(); this.waitForSleepState = new Promise(resolve => { @@ -188,6 +189,7 @@ export class SyncEngine { this.authModeStrategy, errorHandler, this.amplifyContext, + this.datastoreConfig, ); this.mutationsProcessor = new MutationProcessor( diff --git a/packages/datastore/src/sync/processors/subscription.ts b/packages/datastore/src/sync/processors/subscription.ts index c508c8d5885..c6934b3e8d3 100644 --- a/packages/datastore/src/sync/processors/subscription.ts +++ b/packages/datastore/src/sync/processors/subscription.ts @@ -41,6 +41,7 @@ import { getTokenForCustomAuth, getUserGroupsFromToken, predicateToGraphQLFilter, + processSubscriptionVariables, } from '../utils'; import { ModelPredicateCreator } from '../../predicates'; import { validatePredicate } from '../../util'; @@ -75,6 +76,12 @@ class SubscriptionProcessor { private buffer: [TransformerMutationType, SchemaModel, PersistentModel][] = []; + // Cache for subscription variables to avoid repeated function calls + private variablesCache = new WeakMap< + SchemaModel, + Map | null> + >(); + private dataObserver!: Observer; private runningProcesses = new BackgroundProcessManager(); @@ -91,6 +98,7 @@ class SubscriptionProcessor { private readonly amplifyContext: AmplifyContext = { InternalAPI, }, + private readonly datastoreConfig?: Record, ) {} private buildSubscription( @@ -120,6 +128,16 @@ class SubscriptionProcessor { authMode, ) || {}; + // Get custom subscription variables from DataStore config + const customVariables = this.datastoreConfig?.subscriptionVariables + ? processSubscriptionVariables( + model, + transformerMutationType, + this.datastoreConfig.subscriptionVariables[model.name], + this.variablesCache, + ) + : undefined; + const [opType, opName, query] = buildSubscriptionGraphQLOperation( namespace, model, @@ -127,6 +145,7 @@ class SubscriptionProcessor { isOwner, ownerField!, filterArg, + customVariables, ); return { authMode, opType, opName, query, isOwner, ownerField, ownerValue }; @@ -369,6 +388,68 @@ class SubscriptionProcessor { action: DataStoreAction.Subscribe, }; + // Add custom subscription variables from DataStore config + const customVars = this.datastoreConfig + ?.subscriptionVariables + ? processSubscriptionVariables( + modelDefinition, + operation, + this.datastoreConfig.subscriptionVariables[ + modelDefinition.name + ], + this.variablesCache, + ) + : undefined; + + if (customVars) { + // Check for reserved keys that would conflict + const reservedKeys = [ + 'filter', + 'owner', + 'limit', + 'nextToken', + 'sortDirection', + ]; + + const safeVars: Record = {}; + let hasConflicts = false; + + // Safe iteration that handles Object.create(null) + try { + for (const [key, value] of Object.entries(customVars)) { + if (reservedKeys.includes(key)) { + hasConflicts = true; + } else { + safeVars[key] = value; + } + } + } catch (entriesError) { + // Fallback for objects without prototype + for (const key in customVars) { + if ( + Object.prototype.hasOwnProperty.call( + customVars, + key, + ) + ) { + if (reservedKeys.includes(key)) { + hasConflicts = true; + } else { + safeVars[key] = customVars[key]; + } + } + } + } + + if (hasConflicts) { + logger.warn( + `subscriptionVariables for ${modelDefinition.name} contains reserved keys that were filtered out`, + ); + } + + Object.assign(variables, safeVars); + } + if (addFilter && predicatesGroup) { (variables as any).filter = predicateToGraphQLFilter(predicatesGroup); @@ -657,6 +738,8 @@ class SubscriptionProcessor { public async stop() { await this.runningProcesses.close(); await this.runningProcesses.open(); + // Clear cache on stop + this.variablesCache = new WeakMap(); } private passesPredicateValidation( diff --git a/packages/datastore/src/sync/utils.ts b/packages/datastore/src/sync/utils.ts index 23830060f58..48c6a0dd46e 100644 --- a/packages/datastore/src/sync/utils.ts +++ b/packages/datastore/src/sync/utils.ts @@ -318,6 +318,7 @@ export function buildSubscriptionGraphQLOperation( isOwnerAuthorization: boolean, ownerField: string, filterArg = false, + customVariables?: Record, ): [TransformerMutationType, string, string] { const selectionSet = generateSelectionSet(namespace, modelDefinition); @@ -338,6 +339,33 @@ export function buildSubscriptionGraphQLOperation( opArgs.push(`${ownerField}: $${ownerField}`); } + if (customVariables) { + const VALID_VAR_NAME = /^[_a-zA-Z][_a-zA-Z0-9]*$/; + + Object.keys(customVariables).forEach(varName => { + if (!VALID_VAR_NAME.test(varName)) { + logger.warn( + `Invalid GraphQL variable name '${varName}' in subscriptionVariables. Skipping.`, + ); + + return; + } + + if ( + customVariables[varName] === null || + customVariables[varName] === undefined + ) { + return; + } + + const varType = Array.isArray(customVariables[varName]) + ? '[String]' + : 'String'; + docArgs.push(`$${varName}: ${varType}`); + opArgs.push(`${varName}: $${varName}`); + }); + } + const docStr = docArgs.length ? `(${docArgs.join(',')})` : ''; const opStr = opArgs.length ? `(${opArgs.join(',')})` : ''; @@ -961,3 +989,130 @@ export function getIdentifierValue( return idOrPk; } + +const RESERVED_SUBSCRIPTION_VARIABLE_NAMES = new Set([ + 'input', + 'condition', + 'filter', + 'owner', + 'and', + 'or', + 'not', + 'eq', + 'ne', + 'gt', + 'ge', + 'lt', + 'le', + 'contains', + 'notContains', + 'beginsWith', + 'between', + 'in', + 'notIn', + 'limit', + 'nextToken', + 'sortDirection', +]); + +export function processSubscriptionVariables( + model: SchemaModel, + operation: TransformerMutationType, + modelVariables: + | Record + | ((operation: TransformerMutationType) => Record) + | undefined, + cache: WeakMap< + SchemaModel, + Map | null> + >, +): Record | undefined { + if (!modelVariables) { + return undefined; + } + + let modelCache = cache.get(model); + if (!modelCache) { + modelCache = new Map(); + cache.set(model, modelCache); + } + + if (modelCache.has(operation)) { + const cached = modelCache.get(operation); + + return cached || undefined; + } + + let vars: Record; + + if (typeof modelVariables === 'function') { + try { + vars = modelVariables(operation); + } catch (error) { + logger.warn( + `Error evaluating subscriptionVariables function for model ${model.name}:`, + error, + ); + modelCache.set(operation, null); + + return undefined; + } + } else { + vars = modelVariables; + } + + const sanitized = sanitizeSubscriptionVariables(vars, model.name); + modelCache.set(operation, sanitized); + + return sanitized || undefined; +} + +function sanitizeSubscriptionVariables( + vars: any, + modelName: string, +): Record | null { + if (vars === null || typeof vars !== 'object' || Array.isArray(vars)) { + logger.warn( + `subscriptionVariables must be an object for model ${modelName}`, + ); + + return null; + } + + try { + const cloned = JSON.parse(JSON.stringify(vars)); + + return filterReservedSubscriptionVariableKeys(cloned); + } catch { + return filterReservedSubscriptionVariableKeys({ ...vars }); + } +} + +function filterReservedSubscriptionVariableKeys( + vars: Record, +): Record | null { + const result: Record = {}; + + try { + Object.entries(vars).forEach(([key, value]) => { + if (!RESERVED_SUBSCRIPTION_VARIABLE_NAMES.has(key)) { + result[key] = value; + } else { + logger.warn( + `Ignoring reserved GraphQL variable name '${key}' in subscription variables`, + ); + } + }); + } catch { + for (const key in vars) { + if ( + Object.prototype.hasOwnProperty.call(vars, key) && + !RESERVED_SUBSCRIPTION_VARIABLE_NAMES.has(key) + ) { + result[key] = vars[key]; + } + } + } + + return Object.keys(result).length > 0 ? result : null; +} diff --git a/packages/datastore/src/types.ts b/packages/datastore/src/types.ts index 1bd59d78f6c..4686dc04333 100644 --- a/packages/datastore/src/types.ts +++ b/packages/datastore/src/types.ts @@ -1044,6 +1044,11 @@ export interface DataStoreConfig { syncExpressions?: SyncExpression[]; authProviders?: AuthProviders; storageAdapter?: Adapter; + subscriptionVariables?: Record< + string, + | Record + | ((operation: 'CREATE' | 'UPDATE' | 'DELETE') => Record) + >; }; authModeStrategyType?: AuthModeStrategyType; conflictHandler?: ConflictHandler; // default : retry until client wins up to x times