diff --git a/package.json b/package.json index af2525b..61d60dd 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@outerbase/sdk", - "version": "2.0.0-rc.2", + "version": "2.0.0-rc.3", "description": "", "main": "dist/index.js", "module": "dist/index.js", diff --git a/src/connections/bigquery.ts b/src/connections/bigquery.ts index 8edcab4..04d8138 100644 --- a/src/connections/bigquery.ts +++ b/src/connections/bigquery.ts @@ -204,7 +204,7 @@ export class BigQueryConnection extends SqlConnection { * @param parameters - An object containing the parameters to be used in the query. * @returns Promise<{ data: any, error: Error | null }> */ - async query>( + async internalQuery>( query: Query ): Promise> { try { diff --git a/src/connections/index.ts b/src/connections/index.ts index 250da7b..4199a72 100644 --- a/src/connections/index.ts +++ b/src/connections/index.ts @@ -32,7 +32,10 @@ export abstract class Connection { // Retrieve metadata about the database, useful for introspection. abstract fetchDatabaseSchema(): Promise; - abstract raw(query: string): Promise; + abstract raw( + query: string, + params?: Record | unknown[] + ): Promise; abstract testConnection(): Promise<{ error?: string }>; // Connection common operations that will be used by Outerbase diff --git a/src/connections/motherduck.ts b/src/connections/motherduck.ts index d4dcc7b..06f181b 100644 --- a/src/connections/motherduck.ts +++ b/src/connections/motherduck.ts @@ -51,7 +51,7 @@ export class DuckDBConnection extends PostgreBaseConnection { * @param parameters - An object containing the parameters to be used in the query. * @returns Promise<{ data: any, error: Error | null }> */ - async query>( + async internalQuery>( query: Query ): Promise> { const connection = this.connection; diff --git a/src/connections/mysql.ts b/src/connections/mysql.ts index 6a002ac..33321b4 100644 --- a/src/connections/mysql.ts +++ b/src/connections/mysql.ts @@ -225,7 +225,7 @@ export class MySQLConnection extends SqlConnection { return super.mapDataType(dataType); } - async query>( + async internalQuery>( query: Query ): Promise> { try { diff --git a/src/connections/postgre/postgresql.ts b/src/connections/postgre/postgresql.ts index 6787e5b..4c36088 100644 --- a/src/connections/postgre/postgresql.ts +++ b/src/connections/postgre/postgresql.ts @@ -3,22 +3,16 @@ import { QueryResult } from '..'; import { Query } from '../../query'; import { AbstractDialect } from './../../query-builder'; import { PostgresDialect } from './../../query-builder/dialects/postgres'; -import { QueryType } from './../../query-params'; import { createErrorResult, transformArrayBasedResult, } from './../../utils/transformer'; import { PostgreBaseConnection } from './base'; -function replacePlaceholders(query: string): string { - let index = 1; - return query.replace(/\?/g, () => `$${index++}`); -} - export class PostgreSQLConnection extends PostgreBaseConnection { client: Client; dialect: AbstractDialect = new PostgresDialect(); - queryType: QueryType = QueryType.positional; + protected numberedPlaceholder = true; constructor(pgClient: any) { super(); @@ -33,15 +27,12 @@ export class PostgreSQLConnection extends PostgreBaseConnection { await this.client.end(); } - async query>( + async internalQuery>( query: Query ): Promise> { try { const { rows, fields } = await this.client.query({ - text: - query.parameters?.length === 0 - ? query.query - : replacePlaceholders(query.query), + text: query.query, rowMode: 'array', values: query.parameters as unknown[], }); diff --git a/src/connections/snowflake/snowflake.ts b/src/connections/snowflake/snowflake.ts index b4a9798..fbfd365 100644 --- a/src/connections/snowflake/snowflake.ts +++ b/src/connections/snowflake/snowflake.ts @@ -175,7 +175,7 @@ export class SnowflakeConnection extends PostgreBaseConnection { ); } - async query>( + async internalQuery>( query: Query ): Promise> { try { diff --git a/src/connections/sql-base.ts b/src/connections/sql-base.ts index 9cba01d..89d62eb 100644 --- a/src/connections/sql-base.ts +++ b/src/connections/sql-base.ts @@ -7,11 +7,16 @@ import { } from '..'; import { AbstractDialect, ColumnDataType } from './../query-builder'; import { TableColumn, TableColumnDefinition } from './../models/database'; +import { + namedPlaceholder, + toNumberedPlaceholders, +} from './../utils/placeholder'; export abstract class SqlConnection extends Connection { abstract dialect: AbstractDialect; + protected numberedPlaceholder = false; - abstract query>( + abstract internalQuery>( query: Query ): Promise>; @@ -21,8 +26,56 @@ export abstract class SqlConnection extends Connection { return dataType; } - async raw(query: string): Promise { - return await this.query({ query }); + /** + * This is a deprecated function, use raw instead. We keep this for + * backward compatibility. + * + * @deprecated + * @param query + * @returns + */ + async query>( + query: Query + ): Promise> { + return (await this.raw( + query.query, + query.parameters + )) as QueryResult; + } + + async raw( + query: string, + params?: Record | unknown[] + ): Promise { + if (!params) return await this.internalQuery({ query }); + + // Positional placeholder + if (Array.isArray(params)) { + if (this.numberedPlaceholder) { + const { query: newQuery, bindings } = toNumberedPlaceholders( + query, + params + ); + + return await this.internalQuery({ + query: newQuery, + parameters: bindings, + }); + } + + return await this.internalQuery({ query, parameters: params }); + } + + // Named placeholder + const { query: newQuery, bindings } = namedPlaceholder( + query, + params!, + this.numberedPlaceholder + ); + return await this.internalQuery({ + query: newQuery, + parameters: bindings, + }); } async select( diff --git a/src/connections/sqlite/cloudflare.ts b/src/connections/sqlite/cloudflare.ts index 0ac5190..adf529a 100644 --- a/src/connections/sqlite/cloudflare.ts +++ b/src/connections/sqlite/cloudflare.ts @@ -107,7 +107,7 @@ export class CloudflareD1Connection extends SqliteBaseConnection { * @param parameters - An object containing the parameters to be used in the query. * @returns Promise<{ data: any, error: Error | null }> */ - async query>( + async internalQuery>( query: Query ): Promise> { if (!this.apiKey) throw new Error('Cloudflare API key is not set'); diff --git a/src/connections/sqlite/starbase.ts b/src/connections/sqlite/starbase.ts index 0d8ee5f..98bba73 100644 --- a/src/connections/sqlite/starbase.ts +++ b/src/connections/sqlite/starbase.ts @@ -90,7 +90,7 @@ export class StarbaseConnection extends SqliteBaseConnection { * @param parameters - An object containing the parameters to be used in the query. * @returns Promise<{ data: any, error: Error | null }> */ - async query>( + async internalQuery>( query: Query ): Promise> { if (!this.url) throw new Error('Starbase URL is not set'); diff --git a/src/connections/sqlite/turso.ts b/src/connections/sqlite/turso.ts index f323608..4e03cd2 100644 --- a/src/connections/sqlite/turso.ts +++ b/src/connections/sqlite/turso.ts @@ -16,7 +16,7 @@ export class TursoConnection extends SqliteBaseConnection { this.client = client; } - async query>( + async internalQuery>( query: Query ): Promise> { try { diff --git a/src/utils/placeholder.ts b/src/utils/placeholder.ts new file mode 100644 index 0000000..ad240eb --- /dev/null +++ b/src/utils/placeholder.ts @@ -0,0 +1,157 @@ +const RE_PARAM = /(?:\?)|(?::(\d+|(?:[a-zA-Z][a-zA-Z0-9_]*)))/g, + DQUOTE = 34, + SQUOTE = 39, + BSLASH = 92; + +/** + * This code is based on https://github.com/mscdex/node-mariasql/blob/master/lib/Client.js#L296-L420 + * License: https://github.com/mscdex/node-mariasql/blob/master/LICENSE + * + * @param query + * @returns + */ +function parse(query: string): [string] | [string[], (string | number)[]] { + let ppos = RE_PARAM.exec(query); + let curpos = 0; + let start = 0; + let end; + const parts = []; + let inQuote = false; + let escape = false; + let qchr; + const tokens = []; + let qcnt = 0; + let lastTokenEndPos = 0; + let i; + + if (ppos) { + do { + for (i = curpos, end = ppos.index; i < end; ++i) { + let chr = query.charCodeAt(i); + if (chr === BSLASH) escape = !escape; + else { + if (escape) { + escape = false; + continue; + } + if (inQuote && chr === qchr) { + if (query.charCodeAt(i + 1) === qchr) { + // quote escaped via "" or '' + ++i; + continue; + } + inQuote = false; + } else if (!inQuote && (chr === DQUOTE || chr === SQUOTE)) { + inQuote = true; + qchr = chr; + } + } + } + if (!inQuote) { + parts.push(query.substring(start, end)); + tokens.push(ppos[0].length === 1 ? qcnt++ : ppos[1]); + start = end + ppos[0].length; + lastTokenEndPos = start; + } + curpos = end + ppos[0].length; + } while ((ppos = RE_PARAM.exec(query))); + + if (tokens.length) { + if (curpos < query.length) { + parts.push(query.substring(lastTokenEndPos)); + } + return [parts, tokens]; + } + } + return [query]; +} + +export function namedPlaceholder( + query: string, + params: Record, + numbered = false +): { query: string; bindings: unknown[] } { + const parts = parse(query); + + if (parts.length === 1) { + return { query, bindings: [] }; + } + + const bindings = []; + let newQuery = ''; + + const [sqlFragments, placeholders] = parts; + + // If placeholders contains any number, then it's a mix of named and numbered placeholders + if (placeholders.some((p) => typeof p === 'number')) { + throw new Error( + 'Mixing named and positional placeholder should throw error' + ); + } + + for (let i = 0; i < sqlFragments.length; i++) { + newQuery += sqlFragments[i]; + + if (i < placeholders.length) { + const key = placeholders[i]; + + if (numbered) { + newQuery += `$${i + 1}`; + } else { + newQuery += `?`; + } + + const placeholderValue = params[key]; + if (placeholderValue === undefined) { + throw new Error(`Missing value for placeholder ${key}`); + } + + bindings.push(params[key]); + } + } + + return { query: newQuery, bindings }; +} + +export function toNumberedPlaceholders( + query: string, + params: unknown[] +): { + query: string; + bindings: unknown[]; +} { + const parts = parse(query); + + if (parts.length === 1) { + return { query, bindings: [] }; + } + + const bindings = []; + let newQuery = ''; + + const [sqlFragments, placeholders] = parts; + + if (placeholders.length !== params.length) { + throw new Error( + 'Number of positional placeholder should match with the number of values' + ); + } + + // Mixing named and numbered placeholders should throw error + if (placeholders.some((p) => typeof p === 'string')) { + throw new Error( + 'Mixing named and positional placeholder should throw error' + ); + } + + for (let i = 0; i < sqlFragments.length; i++) { + newQuery += sqlFragments[i]; + + if (i < placeholders.length) { + newQuery += `$${i + 1}`; + bindings.push(params[i]); + } + } + + return { query: newQuery, bindings }; +} diff --git a/tests/connections/connection.test.ts b/tests/connections/connection.test.ts index 8c45dc6..8580605 100644 --- a/tests/connections/connection.test.ts +++ b/tests/connections/connection.test.ts @@ -36,6 +36,43 @@ function cleanup(data: Record[]) { } describe('Database Connection', () => { + test('Support named parameters', async () => { + if (process.env.CONNECTION_TYPE === 'mongodb') return; + + const sql = + process.env.CONNECTION_TYPE === 'mysql' + ? 'SELECT CONCAT(:hello, :world) AS testing_word' + : 'SELECT (:hello || :world) AS testing_word'; + + const { data } = await db.raw(sql, { + hello: 'hello ', + world: 'world', + }); + + if (process.env.CONNECTION_TYPE === 'snowflake') { + expect(data).toEqual([{ TESTING_WORD: 'hello world' }]); + } else { + expect(data).toEqual([{ testing_word: 'hello world' }]); + } + }); + + test('Support positional placeholder', async () => { + if (process.env.CONNECTION_TYPE === 'mongodb') return; + + const sql = + process.env.CONNECTION_TYPE === 'mysql' + ? 'SELECT CONCAT(?, ?) AS testing_word' + : 'SELECT (? || ?) AS testing_word'; + + const { data } = await db.raw(sql, ['hello ', 'world']); + + if (process.env.CONNECTION_TYPE === 'snowflake') { + expect(data).toEqual([{ TESTING_WORD: 'hello world' }]); + } else { + expect(data).toEqual([{ testing_word: 'hello world' }]); + } + }); + test('Create table', async () => { const { error: createTableTeamError } = await db.createTable( DEFAULT_SCHEMA, diff --git a/tests/units/placeholder.test.ts b/tests/units/placeholder.test.ts new file mode 100644 index 0000000..695dfb5 --- /dev/null +++ b/tests/units/placeholder.test.ts @@ -0,0 +1,104 @@ +import { + namedPlaceholder, + toNumberedPlaceholders, +} from './../../src/utils/placeholder'; + +test('Positional placeholder', () => { + expect( + namedPlaceholder('SELECT * FROM users WHERE id = :id AND age > :age', { + id: 1, + age: 50, + }) + ).toEqual({ + query: 'SELECT * FROM users WHERE id = ? AND age > ?', + bindings: [1, 50], + }); +}); + +test('Positional placeholder inside the string should be ignored', () => { + expect( + namedPlaceholder( + 'SELECT * FROM users WHERE name = :name AND email = ":email"', + { + name: 'John', + } + ) + ).toEqual({ + query: 'SELECT * FROM users WHERE name = ? AND email = ":email"', + bindings: ['John'], + }); +}); + +test('Named placeholder to number placeholder', () => { + expect( + namedPlaceholder( + 'SELECT * FROM users WHERE id = :id AND age > :age', + { + id: 1, + age: 30, + }, + true + ) + ).toEqual({ + query: 'SELECT * FROM users WHERE id = $1 AND age > $2', + bindings: [1, 30], + }); +}); + +test('Named placeholder to number placeholder with string', () => { + expect( + namedPlaceholder( + 'SELECT * FROM users WHERE id = :id AND email = ":email"', + { + id: 1, + }, + true + ) + ).toEqual({ + query: 'SELECT * FROM users WHERE id = $1 AND email = ":email"', + bindings: [1], + }); +}); + +test('Named placeholder with missing value should throw an error', () => { + expect(() => + namedPlaceholder('SELECT * FROM users WHERE id = :id AND age > :age', { + id: 1, + }) + ).toThrow(); +}); + +test('Number of positional placeholder should match with the number of values', () => { + expect(() => + toNumberedPlaceholders('SELECT * FROM users WHERE id = ? AND age > ?', [ + 1, + ]) + ).toThrow(); +}); + +test('Mixing named and positional placeholder should throw error', () => { + expect(() => + namedPlaceholder('SELECT * FROM users WHERE id = :id AND age > ?', { + id: 1, + }) + ).toThrow(); + + expect(() => { + toNumberedPlaceholders( + `SELECT * FROM users WHERE id = ? AND age > :age`, + [1, 30] + ); + }).toThrow(); +}); + +test('Convert positional placeholder to numbered placeholder', () => { + expect( + toNumberedPlaceholders( + `SELECT * FROM users WHERE id = ? AND email = '?' AND name = 'Outer""base' AND age > ?`, + [1, 30] + ) + ).toEqual({ + query: `SELECT * FROM users WHERE id = $1 AND email = '?' AND name = 'Outer""base' AND age > $2`, + bindings: [1, 30], + }); +});