Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions row-level-security/script.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,61 @@ function bypassRLS() {
);
}

function forCompany(companyId: string) {
return Prisma.defineExtension((prisma) =>
prisma.$extends({

export function forCompany(companyId: string) {
return Prisma.defineExtension((prisma) => {
// Store the original $transaction method
const originalTransaction = prisma.$transaction.bind(prisma);

return prisma.$extends({
query: {
$allModels: {
async $allOperations({ args, query }) {
const [, result] = await prisma.$transaction([
async $allOperations({ model, args, query }) {
try {
// Check if we're already inside a transaction.
const internalParams = (arguments[0] as any).__internalParams;

if (internalParams?.transaction != null) {
// We're already in a transaction, just execute the query.
// The RLS config should have been set by the transaction wrapper.
return query(args);
}

// Not in a transaction, wrap in a batch transaction as before.
const [, result] = await originalTransaction([
prisma.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`,
query(args),
]);
return result;
},
}
},
},
})
);
client: {
// Override $transaction to handle RLS setup.
$transaction: ((...args: Parameters<typeof prisma.$transaction>) => {
const [input, options] = args;

// Check if it's an interactive transaction (function passed).
if (typeof input === "function") {
return originalTransaction(
async (tx: Parameters<typeof input>[0]) => {
await tx.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`;
return input(tx);
},
options
);
} else {
const batch = Array.isArray(input) ? input : [input];
return originalTransaction(
[
prisma.$executeRaw`SELECT set_config('app.current_company_id', ${companyId}, TRUE)`,
...batch,
],
options
);
}
}) as typeof prisma.$transaction,
},
});
});
}

const prisma = new PrismaClient();
Expand Down