diff --git a/.changeset/violet-garlics-know.md b/.changeset/violet-garlics-know.md new file mode 100644 index 000000000..cb973e611 --- /dev/null +++ b/.changeset/violet-garlics-know.md @@ -0,0 +1,5 @@ +--- +'@powersync/service-sync-rules': minor +--- + +Support json_each as a table-valued function. diff --git a/packages/sync-rules/src/SqlParameterQuery.ts b/packages/sync-rules/src/SqlParameterQuery.ts index 41c083d58..92c6cf527 100644 --- a/packages/sync-rules/src/SqlParameterQuery.ts +++ b/packages/sync-rules/src/SqlParameterQuery.ts @@ -24,6 +24,7 @@ import { } from './types.js'; import { filterJsonRow, getBucketId, isJsonValue, isSelectStatement } from './utils.js'; import { SyncRulesOptions } from './SqlSyncRules.js'; +import { TableValuedFunctionSqlParameterQuery } from './TableValuedFunctionSqlParameterQuery.js'; /** * Represents a parameter query, such as: @@ -57,11 +58,16 @@ export class SqlParameterQuery { rows.errors.push(...checkUnsupportedFeatures(sql, q)); - if (q.from.length != 1 || q.from[0].type != 'table') { + if (q.from.length != 1) { throw new SqlRuleError('Must SELECT from a single table', sql, q.from?.[0]._location); + } else if (q.from[0].type == 'call') { + const from = q.from[0]; + return TableValuedFunctionSqlParameterQuery.fromSql(descriptor_name, sql, from, q, options); + } else if (q.from[0].type == 'statement') { + throw new SqlRuleError('Subqueries are not supported yet', sql, q.from?.[0]._location); } - const tableRef = q.from?.[0].name; + const tableRef = q.from[0].name; if (tableRef?.name == null) { throw new SqlRuleError('Must SELECT from a single table', sql, q.from?.[0]._location); } diff --git a/packages/sync-rules/src/TableValuedFunctionSqlParameterQuery.ts b/packages/sync-rules/src/TableValuedFunctionSqlParameterQuery.ts new file mode 100644 index 000000000..5537fe2bb --- /dev/null +++ b/packages/sync-rules/src/TableValuedFunctionSqlParameterQuery.ts @@ -0,0 +1,196 @@ +import { FromCall, SelectedColumn, SelectFromStatement } from 'pgsql-ast-parser'; +import { SqlRuleError } from './errors.js'; +import { SqlTools } from './sql_filters.js'; +import { checkUnsupportedFeatures, isClauseError, isParameterValueClause, sqliteBool } from './sql_support.js'; +import { TABLE_VALUED_FUNCTIONS, TableValuedFunction } from './TableValuedFunctions.js'; +import { + ParameterValueClause, + ParameterValueSet, + QueryParseOptions, + RequestParameters, + SqliteJsonValue, + SqliteRow +} from './types.js'; +import { getBucketId, isJsonValue } from './utils.js'; + +/** + * Represents a parameter query using a table-valued function. + * + * Right now this only supports json_each: + * + * SELECT json_each.value as v FROM json_each(request.parameters() -> 'array') + * + * This can currently not be combined with parameter table queries or multiple table-valued functions. + */ +export class TableValuedFunctionSqlParameterQuery { + static fromSql( + descriptor_name: string, + sql: string, + call: FromCall, + q: SelectFromStatement, + options?: QueryParseOptions + ): TableValuedFunctionSqlParameterQuery { + const query = new TableValuedFunctionSqlParameterQuery(); + + query.errors.push(...checkUnsupportedFeatures(sql, q)); + + if (!(call.function.name in TABLE_VALUED_FUNCTIONS)) { + query.errors.push(new SqlRuleError(`Table-valued function ${call.function.name} is not defined.`, sql, call)); + return query; + } + + const callTable = call.alias?.name ?? call.function.name; + const callExpression = call.args[0]; + + const tools = new SqlTools({ + table: callTable, + parameter_tables: ['token_parameters', 'user_parameters', callTable], + supports_parameter_expressions: true, + sql + }); + const where = q.where; + + const filter = tools.compileParameterValueExtractor(where); + const callClause = tools.compileParameterValueExtractor(callExpression); + const columns = q.columns ?? []; + const bucket_parameters = columns.map((column) => tools.getOutputName(column)); + + query.sql = sql; + query.descriptor_name = descriptor_name; + query.bucket_parameters = bucket_parameters; + query.columns = columns; + query.tools = tools; + query.function = TABLE_VALUED_FUNCTIONS[call.function.name]!; + query.callTableName = callTable; + if (!isClauseError(callClause)) { + query.callClause = callClause; + } + if (!isClauseError(filter)) { + query.filter = filter; + } + + for (let column of columns) { + if (column.alias != null) { + tools.checkSpecificNameCase(column.alias); + } + const name = tools.getSpecificOutputName(column); + const extractor = tools.compileParameterValueExtractor(column.expr); + if (isClauseError(extractor)) { + // Error logged already + continue; + } + query.parameter_extractors[name] = extractor; + } + + query.errors.push(...tools.errors); + + if (query.usesDangerousRequestParameters && !options?.accept_potentially_dangerous_queries) { + let err = new SqlRuleError( + "Potentially dangerous query based on parameters set by the client. The client can send any value for these parameters so it's not a good place to do authorization.", + sql + ); + err.type = 'warning'; + query.errors.push(err); + } + return query; + } + + sql?: string; + columns?: SelectedColumn[]; + parameter_extractors: Record = {}; + descriptor_name?: string; + /** _Output_ bucket parameters */ + bucket_parameters?: string[]; + id?: string; + tools?: SqlTools; + + filter?: ParameterValueClause; + callClause?: ParameterValueClause; + function?: TableValuedFunction; + callTableName?: string; + + errors: SqlRuleError[] = []; + + getStaticBucketIds(parameters: RequestParameters): string[] { + if (this.filter == null || this.callClause == null) { + // Error in filter clause + return []; + } + + const valueString = this.callClause.lookupParameterValue(parameters); + const rows = this.function!.call([valueString]); + let total: string[] = []; + for (let row of rows) { + total.push(...this.getIndividualBucketIds(row, parameters)); + } + return total; + } + + private getIndividualBucketIds(row: SqliteRow, parameters: RequestParameters): string[] { + const mergedParams: ParameterValueSet = { + raw_token_payload: parameters.raw_token_payload, + raw_user_parameters: parameters.raw_user_parameters, + user_id: parameters.user_id, + lookup: (table, column) => { + if (table == this.callTableName) { + return row[column]!; + } else { + return parameters.lookup(table, column); + } + } + }; + const filterValue = this.filter!.lookupParameterValue(mergedParams); + if (sqliteBool(filterValue) === 0n) { + return []; + } + + let result: Record = {}; + for (let name of this.bucket_parameters!) { + const value = this.parameter_extractors[name].lookupParameterValue(mergedParams); + if (isJsonValue(value)) { + result[`bucket.${name}`] = value; + } else { + throw new Error(`Invalid parameter value: ${value}`); + } + } + + return [getBucketId(this.descriptor_name!, this.bucket_parameters!, result)]; + } + + get hasAuthenticatedBucketParameters(): boolean { + // select where request.jwt() ->> 'role' == 'authorized' + // we do not count this as a sufficient check + // const authenticatedFilter = this.filter!.usesAuthenticatedRequestParameters; + + // select request.user_id() as user_id + const authenticatedExtractor = + Object.values(this.parameter_extractors).find( + (clause) => isParameterValueClause(clause) && clause.usesAuthenticatedRequestParameters + ) != null; + + // select value from json_each(request.jwt() ->> 'project_ids') + const authenticatedArgument = this.callClause?.usesAuthenticatedRequestParameters ?? false; + + return authenticatedExtractor || authenticatedArgument; + } + + get usesUnauthenticatedRequestParameters(): boolean { + // select where request.parameters() ->> 'include_comments' + const unauthenticatedFilter = this.filter?.usesUnauthenticatedRequestParameters; + + // select request.parameters() ->> 'project_id' + const unauthenticatedExtractor = + Object.values(this.parameter_extractors).find( + (clause) => isParameterValueClause(clause) && clause.usesUnauthenticatedRequestParameters + ) != null; + + // select value from json_each(request.parameters() ->> 'project_ids') + const unauthenticatedArgument = this.callClause?.usesUnauthenticatedRequestParameters ?? false; + + return unauthenticatedFilter || unauthenticatedExtractor || unauthenticatedArgument; + } + + get usesDangerousRequestParameters() { + return this.usesUnauthenticatedRequestParameters && !this.hasAuthenticatedBucketParameters; + } +} diff --git a/packages/sync-rules/src/TableValuedFunctions.ts b/packages/sync-rules/src/TableValuedFunctions.ts new file mode 100644 index 000000000..e3e40165c --- /dev/null +++ b/packages/sync-rules/src/TableValuedFunctions.ts @@ -0,0 +1,45 @@ +import { SqliteJsonValue, SqliteRow, SqliteValue } from './types.js'; +import { jsonValueToSqlite } from './utils.js'; + +export interface TableValuedFunction { + readonly name: string; + call: (args: SqliteValue[]) => SqliteRow[]; + detail: string; + documentation: string; +} + +export const JSON_EACH: TableValuedFunction = { + name: 'json_each', + call(args: SqliteValue[]) { + if (args.length != 1) { + throw new Error(`json_each expects 1 argument, got ${args.length}`); + } + const valueString = args[0]; + if (valueString === null) { + return []; + } else if (typeof valueString !== 'string') { + throw new Error(`Expected json_each to be called with a string, got ${valueString}`); + } + let values: SqliteJsonValue[] = []; + try { + values = JSON.parse(valueString); + } catch (e) { + throw new Error('Expected JSON string'); + } + if (!Array.isArray(values)) { + throw new Error('Expected an array'); + } + + return values.map((v) => { + return { + value: jsonValueToSqlite(v) + }; + }); + }, + detail: 'Each element of a JSON array', + documentation: 'Returns each element of a JSON array as a separate row.' +}; + +export const TABLE_VALUED_FUNCTIONS: Record = { + json_each: JSON_EACH +}; diff --git a/packages/sync-rules/src/request_functions.ts b/packages/sync-rules/src/request_functions.ts index 941daeb98..b99c88911 100644 --- a/packages/sync-rules/src/request_functions.ts +++ b/packages/sync-rules/src/request_functions.ts @@ -1,9 +1,9 @@ import { ExpressionType } from './ExpressionType.js'; -import { RequestParameters, SqliteValue } from './types.js'; +import { ParameterValueSet, SqliteValue } from './types.js'; export interface SqlParameterFunction { readonly debugName: string; - call: (parameters: RequestParameters) => SqliteValue; + call: (parameters: ParameterValueSet) => SqliteValue; getReturnType(): ExpressionType; /** request.user_id(), request.jwt(), token_parameters.* */ usesAuthenticatedRequestParameters: boolean; @@ -15,7 +15,7 @@ export interface SqlParameterFunction { const request_parameters: SqlParameterFunction = { debugName: 'request.parameters', - call(parameters: RequestParameters) { + call(parameters: ParameterValueSet) { return parameters.raw_user_parameters; }, getReturnType() { @@ -30,7 +30,7 @@ const request_parameters: SqlParameterFunction = { const request_jwt: SqlParameterFunction = { debugName: 'request.jwt', - call(parameters: RequestParameters) { + call(parameters: ParameterValueSet) { return parameters.raw_token_payload; }, getReturnType() { @@ -44,7 +44,7 @@ const request_jwt: SqlParameterFunction = { const request_user_id: SqlParameterFunction = { debugName: 'request.user_id', - call(parameters: RequestParameters) { + call(parameters: ParameterValueSet) { return parameters.user_id; }, getReturnType() { diff --git a/packages/sync-rules/src/sql_filters.ts b/packages/sync-rules/src/sql_filters.ts index 5cf9961f3..14b41c824 100644 --- a/packages/sync-rules/src/sql_filters.ts +++ b/packages/sync-rules/src/sql_filters.ts @@ -500,7 +500,8 @@ export class SqlTools { if (expr.type != 'ref') { return false; } - return this.parameter_tables.includes(expr.table?.name ?? ''); + const tableName = expr.table?.name ?? this.default_table; + return this.parameter_tables.includes(tableName ?? ''); } /** @@ -585,13 +586,12 @@ export class SqlTools { } getParameterRefClause(expr: ExprRef): ParameterValueClause { - const table = expr.table!.name; + const table = (expr.table?.name ?? this.default_table)!; const column = expr.name; return { key: `${table}.${column}`, lookupParameterValue: (parameters) => { - const pt: SqliteJsonRow | undefined = (parameters as any)[table]; - return pt?.[column] ?? null; + return parameters.lookup(table, column); }, usesAuthenticatedRequestParameters: table == 'token_parameters', usesUnauthenticatedRequestParameters: table == 'user_parameters' @@ -607,18 +607,17 @@ export class SqlTools { * * Only "value" tables are supported here, not parameter values. */ - getTableName(ref: ExprRef) { + getTableName(ref: ExprRef): string { if (this.refHasSchema(ref)) { throw new SqlRuleError(`Specifying schema in column references is not supported`, this.sql, ref); } - if (ref.table?.name == null && this.default_table != null) { - return this.default_table; - } else if (this.value_tables.includes(ref.table?.name ?? '')) { - return ref.table!.name; + const tableName = ref.table?.name ?? this.default_table; + if (this.value_tables.includes(tableName ?? '')) { + return tableName!; } else if (ref.table?.name == null) { throw new SqlRuleError(`Table name required`, this.sql, ref); } else { - throw new SqlRuleError(`Undefined table ${ref.table?.name}`, this.sql, ref); + throw new SqlRuleError(`Undefined table ${tableName}`, this.sql, ref); } } @@ -750,6 +749,8 @@ function staticValue(expr: Expr): SqliteValue { return expr.value ? SQLITE_TRUE : SQLITE_FALSE; } else if (expr.type == 'integer') { return BigInt(expr.value); + } else if (expr.type == 'null') { + return null; } else { return (expr as any).value; } diff --git a/packages/sync-rules/src/sql_functions.ts b/packages/sync-rules/src/sql_functions.ts index fd6f71b91..62ebc7bae 100644 --- a/packages/sync-rules/src/sql_functions.ts +++ b/packages/sync-rules/src/sql_functions.ts @@ -823,9 +823,6 @@ export function jsonExtract(sourceValue: SqliteValue, path: SqliteValue, operato if (operator == '->') { // -> must always stringify return JSONBig.stringify(value); - } else if (typeof value == 'object' || Array.isArray(value)) { - // Objects and arrays must be stringified - return JSONBig.stringify(value); } else { // Plain scalar value - simple conversion. return jsonValueToSqlite(value as string | number | bigint | boolean | null); diff --git a/packages/sync-rules/src/types.ts b/packages/sync-rules/src/types.ts index 5aa8ea378..b506c09c4 100644 --- a/packages/sync-rules/src/types.ts +++ b/packages/sync-rules/src/types.ts @@ -73,7 +73,23 @@ export interface RequestJwtPayload { [key: string]: any; } -export class RequestParameters { +export interface ParameterValueSet { + lookup(table: string, column: string): SqliteValue; + + /** + * JSON string of raw request parameters. + */ + raw_user_parameters: string; + + /** + * JSON string of raw request parameters. + */ + raw_token_payload: string; + + user_id: string; +} + +export class RequestParameters implements ParameterValueSet { token_parameters: SqliteJsonRow; user_parameters: SqliteJsonRow; @@ -106,6 +122,15 @@ export class RequestParameters { this.raw_user_parameters = JSONBig.stringify(clientParameters); this.user_parameters = toSyncRulesParameters(clientParameters); } + + lookup(table: string, column: string): SqliteJsonValue { + if (table == 'token_parameters') { + return this.token_parameters[column]; + } else if (table == 'user_parameters') { + return this.user_parameters[column]; + } + throw new Error(`Unknown table: ${table}`); + } } /** @@ -200,7 +225,7 @@ export interface InputParameter { * * Only relevant for parameter queries. */ - parametersToLookupValue(parameters: RequestParameters): SqliteValue; + parametersToLookupValue(parameters: ParameterValueSet): SqliteValue; } export interface EvaluateRowOptions { @@ -276,7 +301,7 @@ export interface ParameterValueClause { * * Only relevant for parameter queries. */ - lookupParameterValue(parameters: RequestParameters): SqliteValue; + lookupParameterValue(parameters: ParameterValueSet): SqliteValue; } export interface QuerySchema { diff --git a/packages/sync-rules/src/utils.ts b/packages/sync-rules/src/utils.ts index 3cf15bc6e..06e34f540 100644 --- a/packages/sync-rules/src/utils.ts +++ b/packages/sync-rules/src/utils.ts @@ -51,9 +51,12 @@ export function filterJsonRow(data: SqliteRow): SqliteJsonRow { * * Types specifically not supported in output are `boolean` and `undefined`. */ -export function jsonValueToSqlite(value: null | undefined | string | number | bigint | boolean): SqliteValue { +export function jsonValueToSqlite(value: null | undefined | string | number | bigint | boolean | any): SqliteValue { if (typeof value == 'boolean') { return value ? SQLITE_TRUE : SQLITE_FALSE; + } else if (typeof value == 'object' || Array.isArray(value)) { + // Objects and arrays must be stringified + return JSONBig.stringify(value); } else { return value ?? null; } diff --git a/packages/sync-rules/test/src/table_valued_function_queries.test.ts b/packages/sync-rules/test/src/table_valued_function_queries.test.ts new file mode 100644 index 000000000..aaa33cac4 --- /dev/null +++ b/packages/sync-rules/test/src/table_valued_function_queries.test.ts @@ -0,0 +1,146 @@ +import { describe, expect, test } from 'vitest'; +import { RequestParameters, SqlParameterQuery } from '../../src/index.js'; +import { StaticSqlParameterQuery } from '../../src/StaticSqlParameterQuery.js'; +import { PARSE_OPTIONS } from './util.js'; + +describe('table-valued function queries', () => { + test('json_each(array param)', function () { + const sql = "SELECT json_each.value as v FROM json_each(request.parameters() -> 'array')"; + const query = SqlParameterQuery.fromSql('mybucket', sql, { + ...PARSE_OPTIONS, + accept_potentially_dangerous_queries: true + }) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['v']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, { array: [1, 2, 3] }))).toEqual([ + 'mybucket[1]', + 'mybucket[2]', + 'mybucket[3]' + ]); + }); + + test('json_each(static string)', function () { + const sql = `SELECT json_each.value as v FROM json_each('[1,2,3]')`; + const query = SqlParameterQuery.fromSql('mybucket', sql, PARSE_OPTIONS) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['v']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, {}))).toEqual([ + 'mybucket[1]', + 'mybucket[2]', + 'mybucket[3]' + ]); + }); + + test('json_each(null)', function () { + const sql = `SELECT json_each.value as v FROM json_each(null)`; + const query = SqlParameterQuery.fromSql('mybucket', sql, PARSE_OPTIONS) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['v']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, {}))).toEqual([]); + }); + + test('json_each with fn alias', function () { + const sql = "SELECT e.value FROM json_each(request.parameters() -> 'array') e"; + const query = SqlParameterQuery.fromSql('mybucket', sql, { + ...PARSE_OPTIONS, + accept_potentially_dangerous_queries: true + }) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['value']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, { array: [1, 2, 3] }))).toEqual([ + 'mybucket[1]', + 'mybucket[2]', + 'mybucket[3]' + ]); + }); + + test('json_each with direct value', function () { + const sql = "SELECT value FROM json_each(request.parameters() -> 'array')"; + const query = SqlParameterQuery.fromSql('mybucket', sql, { + ...PARSE_OPTIONS, + accept_potentially_dangerous_queries: true + }) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['value']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, { array: [1, 2, 3] }))).toEqual([ + 'mybucket[1]', + 'mybucket[2]', + 'mybucket[3]' + ]); + }); + + test('json_each in filters (1)', function () { + const sql = "SELECT value as v FROM json_each(request.parameters() -> 'array') e WHERE e.value >= 2"; + const query = SqlParameterQuery.fromSql('mybucket', sql, { + ...PARSE_OPTIONS, + accept_potentially_dangerous_queries: true + }) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['v']); + + expect(query.getStaticBucketIds(new RequestParameters({ sub: '' }, { array: [1, 2, 3] }))).toEqual([ + 'mybucket[2]', + 'mybucket[3]' + ]); + }); + + test('json_each with nested json', function () { + const sql = + "SELECT value ->> 'id' as project_id FROM json_each(request.jwt() -> 'projects') WHERE (value ->> 'role') = 'admin'"; + const query = SqlParameterQuery.fromSql('mybucket', sql, { + ...PARSE_OPTIONS, + accept_potentially_dangerous_queries: true + }) as StaticSqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.bucket_parameters).toEqual(['project_id']); + + expect( + query.getStaticBucketIds( + new RequestParameters( + { + sub: '', + projects: [ + { id: 1, role: 'admin' }, + { id: 2, role: 'user' } + ] + }, + {} + ) + ) + ).toEqual(['mybucket[1]']); + }); + + describe('dangerous queries', function () { + function testDangerousQuery(sql: string) { + test(sql, function () { + const query = SqlParameterQuery.fromSql('mybucket', sql, PARSE_OPTIONS) as SqlParameterQuery; + expect(query.errors).toMatchObject([ + { + message: + "Potentially dangerous query based on parameters set by the client. The client can send any value for these parameters so it's not a good place to do authorization." + } + ]); + expect(query.usesDangerousRequestParameters).toEqual(true); + }); + } + function testSafeQuery(sql: string) { + test(sql, function () { + const query = SqlParameterQuery.fromSql('mybucket', sql, PARSE_OPTIONS) as SqlParameterQuery; + expect(query.errors).toEqual([]); + expect(query.usesDangerousRequestParameters).toEqual(false); + }); + } + + testSafeQuery('select value from json_each(request.user_id())'); + testDangerousQuery("select value from json_each(request.parameters() ->> 'project_ids')"); + testSafeQuery("select request.user_id() as user_id, value FROM json_each(request.parameters() ->> 'project_ids')"); + testSafeQuery( + "select request.parameters() ->> 'something' as something, value as project_id FROM json_each(request.jwt() ->> 'project_ids')" + ); + }); +});