diff --git a/row-level-security/script.ts b/row-level-security/script.ts index 233ad30e..d8e26bdd 100644 --- a/row-level-security/script.ts +++ b/row-level-security/script.ts @@ -18,22 +18,61 @@ function bypassRLS() { ); } -function forCompany(companyId: string) { - return Prisma.defineExtension((prisma) => - prisma.$extends({ + +export function forCompany(companyId: string) { + return Prisma.defineExtension((prisma) => { + // Store the original $transaction method + const originalTransaction = prisma.$transaction.bind(prisma); + + return prisma.$extends({ query: { - $allModels: { - async $allOperations({ args, query }) { - const [, result] = await prisma.$transaction([ + async $allOperations({ model, args, query }) { + try { + // Check if we're already inside a transaction. + const internalParams = (arguments[0] as any).__internalParams; + + if (internalParams?.transaction != null) { + // We're already in a transaction, just execute the query. + // The RLS config should have been set by the transaction wrapper. + return query(args); + } + + // Not in a transaction, wrap in a batch transaction as before. + const [, result] = await originalTransaction([ prisma.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`, query(args), ]); - return result; - }, + } }, }, - }) - ); + client: { + // Override $transaction to handle RLS setup. + $transaction: ((...args: Parameters) => { + const [input, options] = args; + + // Check if it's an interactive transaction (function passed). + if (typeof input === "function") { + return originalTransaction( + async (tx: Parameters[0]) => { + await tx.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`; + return input(tx); + }, + options + ); + } else { + const batch = Array.isArray(input) ? input : [input]; + return originalTransaction( + [ + prisma.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`, + ...batch, + ], + options + ); + } + }) as typeof prisma.$transaction, + }, + }); + }); } const prisma = new PrismaClient();