Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions packages/electric-db-collection/src/electric.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
removeTagFromIndex,
tagMatchesPattern,
} from './tag-index'
import type { ColumnEncoder } from './sql-compiler'
import type {
MoveOutPattern,
MoveTag,
Expand Down Expand Up @@ -347,6 +348,7 @@ function createLoadSubsetDedupe<T extends Row<unknown>>({
write,
commit,
collectionId,
encodeColumnName,
}: {
stream: ShapeStream<T>
syncMode: ElectricSyncMode
Expand All @@ -359,17 +361,24 @@ function createLoadSubsetDedupe<T extends Row<unknown>>({
}) => void
commit: () => void
collectionId?: string
/**
* Optional function to encode column names (e.g., camelCase to snake_case).
* This is typically the `encode` function from shapeOptions.columnMapper.
*/
encodeColumnName?: ColumnEncoder
}): DeduplicatedLoadSubset | null {
// Eager mode doesn't need subset loading
if (syncMode === `eager`) {
return null
}

const compileOptions = encodeColumnName ? { encodeColumnName } : undefined

const loadSubset = async (opts: LoadSubsetOptions) => {
// In progressive mode, use fetchSnapshot during snapshot phase
if (isBufferingInitialSync()) {
// Progressive mode snapshot phase: fetch and apply immediately
const snapshotParams = compileSQL<T>(opts)
const snapshotParams = compileSQL<T>(opts, compileOptions)
try {
const { data: rows } = await stream.fetchSnapshot(snapshotParams)

Expand Down Expand Up @@ -428,7 +437,10 @@ function createLoadSubsetDedupe<T extends Row<unknown>>({
orderBy,
// No limit - get all ties
}
const whereCurrentParams = compileSQL<T>(whereCurrentOpts)
const whereCurrentParams = compileSQL<T>(
whereCurrentOpts,
compileOptions,
)
promises.push(stream.requestSnapshot(whereCurrentParams))

debug(
Expand All @@ -442,7 +454,7 @@ function createLoadSubsetDedupe<T extends Row<unknown>>({
orderBy,
limit,
}
const whereFromParams = compileSQL<T>(whereFromOpts)
const whereFromParams = compileSQL<T>(whereFromOpts, compileOptions)
promises.push(stream.requestSnapshot(whereFromParams))

debug(
Expand All @@ -453,7 +465,7 @@ function createLoadSubsetDedupe<T extends Row<unknown>>({
await Promise.all(promises)
} else {
// No cursor - standard single request
const snapshotParams = compileSQL<T>(opts)
const snapshotParams = compileSQL<T>(opts, compileOptions)
await stream.requestSnapshot(snapshotParams)
}
}
Expand Down Expand Up @@ -1296,6 +1308,9 @@ function createElectricSync<T extends Row<unknown>>(
write,
commit,
collectionId,
// Pass the columnMapper's encode function to transform column names
// (e.g., camelCase to snake_case) when compiling SQL for subset queries
encodeColumnName: shapeOptions.columnMapper?.encode,
})

unsubscribeStream = stream.subscribe((messages: Array<Message<T>>) => {
Expand Down
69 changes: 56 additions & 13 deletions packages/electric-db-collection/src/sql-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,42 @@ export type CompiledSqlRecord = Omit<SubsetParams, `params`> & {
params?: Array<unknown>
}

export function compileSQL<T>(options: LoadSubsetOptions): SubsetParams {
/**
* Optional function to encode column names (e.g., camelCase to snake_case)
* This is typically the `encode` function from a columnMapper
*/
export type ColumnEncoder = (columnName: string) => string

/**
* Options for SQL compilation
*/
export interface CompileSQLOptions {
/**
* Optional function to encode column names before quoting.
* Used to transform property names (e.g., camelCase) to database column names (e.g., snake_case).
* This should be the `encode` function from shapeOptions.columnMapper.
*/
encodeColumnName?: ColumnEncoder
}

export function compileSQL<T>(
options: LoadSubsetOptions,
compileOptions?: CompileSQLOptions,
): SubsetParams {
const { where, orderBy, limit } = options
const encodeColumnName = compileOptions?.encodeColumnName

const params: Array<T> = []
const compiledSQL: CompiledSqlRecord = { params }

if (where) {
// TODO: this only works when the where expression's PropRefs directly reference a column of the collection
// doesn't work if it goes through aliases because then we need to know the entire query to be able to follow the reference until the base collection (cf. followRef function)
compiledSQL.where = compileBasicExpression(where, params)
compiledSQL.where = compileBasicExpression(where, params, encodeColumnName)
}

if (orderBy) {
compiledSQL.orderBy = compileOrderBy(orderBy, params)
compiledSQL.orderBy = compileOrderBy(orderBy, params, encodeColumnName)
}

if (limit) {
Expand Down Expand Up @@ -58,21 +80,28 @@ export function compileSQL<T>(options: LoadSubsetOptions): SubsetParams {
* Quote PostgreSQL identifiers to handle mixed case column names correctly.
* Electric/Postgres requires quotes for case-sensitive identifiers.
* @param name - The identifier to quote
* @param encodeColumnName - Optional function to encode the column name before quoting (e.g., camelCase to snake_case)
* @returns The quoted identifier
*/
function quoteIdentifier(name: string): string {
return `"${name}"`
function quoteIdentifier(
name: string,
encodeColumnName?: ColumnEncoder,
): string {
const columnName = encodeColumnName ? encodeColumnName(name) : name
return `"${columnName}"`
}

/**
* Compiles the expression to a SQL string and mutates the params array with the values.
* @param exp - The expression to compile
* @param params - The params array
* @param encodeColumnName - Optional function to encode column names (e.g., camelCase to snake_case)
* @returns The compiled SQL string
*/
function compileBasicExpression(
exp: IR.BasicExpression<unknown>,
params: Array<unknown>,
encodeColumnName?: ColumnEncoder,
): string {
switch (exp.type) {
case `val`:
Expand All @@ -85,29 +114,34 @@ function compileBasicExpression(
`Compiler can't handle nested properties: ${exp.path.join(`.`)}`,
)
}
return quoteIdentifier(exp.path[0]!)
return quoteIdentifier(exp.path[0]!, encodeColumnName)
case `func`:
return compileFunction(exp, params)
return compileFunction(exp, params, encodeColumnName)
default:
throw new Error(`Unknown expression type`)
}
}

function compileOrderBy(orderBy: IR.OrderBy, params: Array<unknown>): string {
function compileOrderBy(
orderBy: IR.OrderBy,
params: Array<unknown>,
encodeColumnName?: ColumnEncoder,
): string {
const compiledOrderByClauses = orderBy.map((clause: IR.OrderByClause) =>
compileOrderByClause(clause, params),
compileOrderByClause(clause, params, encodeColumnName),
)
return compiledOrderByClauses.join(`,`)
}

function compileOrderByClause(
clause: IR.OrderByClause,
params: Array<unknown>,
encodeColumnName?: ColumnEncoder,
): string {
// FIXME: We should handle stringSort and locale.
// Correctly supporting them is tricky as it depends on Postgres' collation
const { expression, compareOptions } = clause
let sql = compileBasicExpression(expression, params)
let sql = compileBasicExpression(expression, params, encodeColumnName)

if (compareOptions.direction === `desc`) {
sql = `${sql} DESC`
Expand All @@ -134,6 +168,7 @@ function isNullValue(exp: IR.BasicExpression<unknown>): boolean {
function compileFunction(
exp: IR.Func<unknown>,
params: Array<unknown> = [],
encodeColumnName?: ColumnEncoder,
): string {
const { name, args } = exp

Expand All @@ -160,7 +195,7 @@ function compileFunction(
}

const compiledArgs = args.map((arg: IR.BasicExpression) =>
compileBasicExpression(arg, params),
compileBasicExpression(arg, params, encodeColumnName),
)

// Special case for IS NULL / IS NOT NULL - these are postfix operators
Expand All @@ -181,7 +216,11 @@ function compileFunction(
if (arg && arg.type === `func`) {
const funcArg = arg
if (funcArg.name === `isNull` || funcArg.name === `isUndefined`) {
const innerArg = compileBasicExpression(funcArg.args[0]!, params)
const innerArg = compileBasicExpression(
funcArg.args[0]!,
params,
encodeColumnName,
)
return `${innerArg} IS NOT NULL`
}
}
Expand Down Expand Up @@ -270,7 +309,11 @@ function compileFunction(
params.pop() // remove LHS (boolean)

// Recompile RHS to get fresh param
const rhsCompiled = compileBasicExpression(rhsArg!, params)
const rhsCompiled = compileBasicExpression(
rhsArg!,
params,
encodeColumnName,
)

// Transform: flip the comparison (val op col → col flipped_op val)
if (name === `lt`) {
Expand Down
115 changes: 115 additions & 0 deletions packages/electric-db-collection/tests/sql-compiler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -308,5 +308,120 @@ describe(`sql-compiler`, () => {
expect(result.limit).toBe(10)
})
})

describe(`column name encoding (camelCase to snake_case)`, () => {
// Helper to simulate snakeCamelMapper's encode function
const camelToSnake = (str: string): string =>
str.replace(/[A-Z]/g, (letter) => `_${letter.toLowerCase()}`)

it(`should encode column names in where clause when encoder is provided`, () => {
const result = compileSQL(
{
where: func(`eq`, [ref(`programTemplateId`), val(`uuid-123`)]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(`"program_template_id" = $1`)
expect(result.params).toEqual({ '1': `uuid-123` })
})

it(`should encode column names in compound where clauses`, () => {
const result = compileSQL(
{
where: func(`and`, [
func(`eq`, [ref(`programTemplateId`), val(`uuid-123`)]),
func(`gt`, [ref(`createdAt`), val(`2024-01-01`)]),
]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(
`"program_template_id" = $1 AND "created_at" > $2`,
)
expect(result.params).toEqual({ '1': `uuid-123`, '2': `2024-01-01` })
})

it(`should encode column names in orderBy clause`, () => {
const result = compileSQL(
{
orderBy: [
{
expression: ref(`createdAt`),
compareOptions: { direction: `desc`, nulls: `last` },
},
],
},
{ encodeColumnName: camelToSnake },
)
expect(result.orderBy).toBe(`"created_at" DESC NULLS LAST`)
})

it(`should encode column names in isNull expressions`, () => {
const result = compileSQL(
{
where: func(`isNull`, [ref(`deletedAt`)]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(`"deleted_at" IS NULL`)
})

it(`should encode column names in NOT isNull expressions`, () => {
const result = compileSQL(
{
where: func(`not`, [func(`isNull`, [ref(`archivedAt`)])]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(`"archived_at" IS NOT NULL`)
})

it(`should not transform column names when no encoder is provided`, () => {
const result = compileSQL({
where: func(`eq`, [ref(`programTemplateId`), val(`uuid-123`)]),
})
// Without encoder, camelCase name is preserved
expect(result.where).toBe(`"programTemplateId" = $1`)
})

it(`should handle complex nested expressions with encoding`, () => {
const result = compileSQL(
{
where: func(`and`, [
func(`eq`, [ref(`userId`), val(`user-1`)]),
func(`or`, [
func(`eq`, [ref(`accountType`), val(`premium`)]),
func(`gte`, [ref(`totalSpend`), val(1000)]),
]),
]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(
`"user_id" = $1 AND "account_type" = $2 OR "total_spend" >= $3`,
)
})

it(`should encode column names in LIKE expressions`, () => {
const result = compileSQL(
{
where: func(`ilike`, [ref(`firstName`), val(`%john%`)]),
},
{ encodeColumnName: camelToSnake },
)
expect(result.where).toBe(`"first_name" ILIKE $1`)
})

it(`should work with already snake_case names (identity transform)`, () => {
const result = compileSQL(
{
where: func(`eq`, [ref(`user_id`), val(`123`)]),
},
{ encodeColumnName: camelToSnake },
)
// snake_case input remains snake_case
expect(result.where).toBe(`"user_id" = $1`)
})
})
})
})
15 changes: 15 additions & 0 deletions packages/react-db/src/useLiveQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,21 @@ export function useLiveQuery(

if (needsNewCollection) {
if (isCollection) {
// Warn when passing a collection directly with on-demand sync mode
// In on-demand mode, data is only loaded when queries with predicates request it
// Passing the collection directly doesn't provide any predicates, so no data loads
const syncMode = (
configOrQueryOrCollection as { config?: { syncMode?: string } }
).config?.syncMode
if (syncMode === `on-demand`) {
console.warn(
`[useLiveQuery] Warning: Passing a collection with syncMode "on-demand" directly to useLiveQuery ` +
`will not load any data. In on-demand mode, data is only loaded when queries with predicates request it.\n\n` +
`Instead, use a query builder function:\n` +
` const { data } = useLiveQuery((q) => q.from({ c: myCollection }).select(({ c }) => c))\n\n` +
`Or switch to syncMode "eager" if you want all data to sync automatically.`,
)
}
// It's already a collection, ensure sync is started for React hooks
configOrQueryOrCollection.startSyncImmediate()
collectionRef.current = configOrQueryOrCollection
Expand Down
Loading
Loading