Skip to content

Commit a2d734c

Browse files
Neon $withAuth (#3562)
* Added `` for `neon-http` driver, updated `@neondatabase/serverless` version, fixed `.catch` on `` requiring strict return type * Fixed package version mismatch --------- Co-authored-by: Andrii Sherman <andreysherman11@gmail.com>
1 parent fcaa0a5 commit a2d734c

File tree

15 files changed

+8090
-4559
lines changed

15 files changed

+8090
-4559
lines changed

drizzle-orm/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"@electric-sql/pglite": ">=0.2.0",
5050
"@libsql/client": ">=0.10.0",
5151
"@libsql/client-wasm": ">=0.10.0",
52-
"@neondatabase/serverless": ">=0.1",
52+
"@neondatabase/serverless": ">=0.10.0",
5353
"@op-engineering/op-sqlite": ">=2",
5454
"@opentelemetry/api": "^1.4.1",
5555
"@planetscale/database": ">=1",
@@ -169,7 +169,7 @@
169169
"@libsql/client": "^0.10.0",
170170
"@libsql/client-wasm": "^0.10.0",
171171
"@miniflare/d1": "^2.14.4",
172-
"@neondatabase/serverless": "^0.9.0",
172+
"@neondatabase/serverless": "^0.10.0",
173173
"@op-engineering/op-sqlite": "^2.0.16",
174174
"@opentelemetry/api": "^1.4.1",
175175
"@originjs/vite-plugin-commonjs": "^1.0.3",

drizzle-orm/src/neon-http/driver.ts

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,67 @@ export class NeonHttpDriver {
4040
}
4141
}
4242

43+
function wrap<T extends object>(
44+
target: T,
45+
token: string,
46+
cb: (target: any, p: string | symbol, res: any) => any,
47+
deep?: boolean,
48+
) {
49+
return new Proxy(target, {
50+
get(target, p) {
51+
const element = target[p as keyof typeof p];
52+
if (typeof element !== 'function' && (typeof element !== 'object' || element === null)) return element;
53+
54+
if (deep) return wrap(element, token, cb);
55+
if (p === 'query') return wrap(element, token, cb, true);
56+
57+
return new Proxy(element as any, {
58+
apply(target, thisArg, argArray) {
59+
const res = target.call(thisArg, ...argArray);
60+
if ('setToken' in res && typeof res.setToken === 'function') {
61+
res.setToken(token);
62+
}
63+
return cb(target, p, res);
64+
},
65+
});
66+
},
67+
});
68+
}
69+
4370
export class NeonHttpDatabase<
4471
TSchema extends Record<string, unknown> = Record<string, never>,
4572
> extends PgDatabase<NeonHttpQueryResultHKT, TSchema> {
4673
static override readonly [entityKind]: string = 'NeonHttpDatabase';
4774

75+
$withAuth(
76+
token: string,
77+
): Omit<
78+
this,
79+
Exclude<
80+
keyof this,
81+
| '$count'
82+
| 'delete'
83+
| 'select'
84+
| 'selectDistinct'
85+
| 'selectDistinctOn'
86+
| 'update'
87+
| 'insert'
88+
| 'with'
89+
| 'query'
90+
| 'execute'
91+
| 'refreshMaterializedView'
92+
>
93+
> {
94+
this.authToken = token;
95+
96+
return wrap(this, token, (target, p, res) => {
97+
if (p === 'with') {
98+
return wrap(res, token, (_, __, res) => res);
99+
}
100+
return res;
101+
});
102+
}
103+
48104
/** @internal */
49105
declare readonly session: NeonHttpSession<TSchema, ExtractTablesWithRelations<TSchema>>;
50106

drizzle-orm/src/neon-http/session.ts

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,43 @@ export class NeonHttpPreparedQuery<T extends PreparedQueryConfig> extends PgPrep
3838
super(query);
3939
}
4040

41-
async execute(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['execute']> {
41+
async execute(placeholderValues: Record<string, unknown> | undefined): Promise<T['execute']>;
42+
/** @internal */
43+
async execute(placeholderValues: Record<string, unknown> | undefined, token?: string): Promise<T['execute']>;
44+
/** @internal */
45+
async execute(
46+
placeholderValues: Record<string, unknown> | undefined = {},
47+
token: string | undefined = this.authToken,
48+
): Promise<T['execute']> {
4249
const params = fillPlaceholders(this.query.params, placeholderValues);
4350

4451
this.logger.logQuery(this.query.sql, params);
4552

4653
const { fields, client, query, customResultMapper } = this;
4754

4855
if (!fields && !customResultMapper) {
49-
return client(query.sql, params, rawQueryConfig);
56+
return client(
57+
query.sql,
58+
params,
59+
token === undefined
60+
? rawQueryConfig
61+
: {
62+
...rawQueryConfig,
63+
authToken: token,
64+
},
65+
);
5066
}
5167

52-
const result = await client(query.sql, params, queryConfig);
68+
const result = await client(
69+
query.sql,
70+
params,
71+
token === undefined
72+
? queryConfig
73+
: {
74+
...queryConfig,
75+
authToken: token,
76+
},
77+
);
5378

5479
return this.mapResult(result);
5580
}
@@ -71,13 +96,26 @@ export class NeonHttpPreparedQuery<T extends PreparedQueryConfig> extends PgPrep
7196
all(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['all']> {
7297
const params = fillPlaceholders(this.query.params, placeholderValues);
7398
this.logger.logQuery(this.query.sql, params);
74-
return this.client(this.query.sql, params, rawQueryConfig).then((result) => result.rows);
99+
return this.client(
100+
this.query.sql,
101+
params,
102+
this.authToken === undefined ? rawQueryConfig : {
103+
...rawQueryConfig,
104+
authToken: this.authToken,
105+
},
106+
).then((result) => result.rows);
75107
}
76108

77-
values(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['values']> {
109+
values(placeholderValues: Record<string, unknown> | undefined): Promise<T['values']>;
110+
/** @internal */
111+
values(placeholderValues: Record<string, unknown> | undefined, token?: string): Promise<T['values']>;
112+
/** @internal */
113+
values(placeholderValues: Record<string, unknown> | undefined = {}, token?: string): Promise<T['values']> {
78114
const params = fillPlaceholders(this.query.params, placeholderValues);
79115
this.logger.logQuery(this.query.sql, params);
80-
return this.client(this.query.sql, params, { arrayMode: true, fullResults: true }).then((result) => result.rows);
116+
return this.client(this.query.sql, params, { arrayMode: true, fullResults: true, authToken: token }).then((
117+
result,
118+
) => result.rows);
81119
}
82120

83121
/** @internal */
@@ -125,7 +163,9 @@ export class NeonHttpSession<
125163
);
126164
}
127165

128-
async batch<U extends BatchItem<'pg'>, T extends Readonly<[U, ...U[]]>>(queries: T) {
166+
async batch<U extends BatchItem<'pg'>, T extends Readonly<[U, ...U[]]>>(
167+
queries: T,
168+
) {
129169
const preparedQueries: PreparedQuery[] = [];
130170
const builtQueries: NeonQueryPromise<any, true>[] = [];
131171

@@ -143,7 +183,7 @@ export class NeonHttpSession<
143183

144184
const batchResults = await this.client.transaction(builtQueries, queryConfig);
145185

146-
return batchResults.map((result, i) => preparedQueries[i]!.mapResult(result, true));
186+
return batchResults.map((result, i) => preparedQueries[i]!.mapResult(result, true)) as any;
147187
}
148188

149189
// change return type to QueryRows<true>
@@ -161,8 +201,12 @@ export class NeonHttpSession<
161201
return this.client(query, params, { arrayMode: false, fullResults: true });
162202
}
163203

164-
override async count(sql: SQL): Promise<number> {
165-
const res = await this.execute<{ rows: [{ count: string }] }>(sql);
204+
override async count(sql: SQL): Promise<number>;
205+
/** @internal */
206+
override async count(sql: SQL, token?: string): Promise<number>;
207+
/** @internal */
208+
override async count(sql: SQL, token?: string): Promise<number> {
209+
const res = await this.execute<{ rows: [{ count: string }] }>(sql, token);
166210

167211
return Number(
168212
res['rows'][0]['count'],

drizzle-orm/src/pg-core/db.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ export class PgDatabase<
597597
return new PgRefreshMaterializedView(view, this.session, this.dialect);
598598
}
599599

600+
protected authToken?: string;
601+
600602
execute<TRow extends Record<string, unknown> = Record<string, unknown>>(
601603
query: SQLWrapper | string,
602604
): PgRaw<PgQueryResultKind<TQueryResult, TRow>> {
@@ -611,7 +613,7 @@ export class PgDatabase<
611613
false,
612614
);
613615
return new PgRaw(
614-
() => prepared.execute(),
616+
() => prepared.execute(undefined, this.authToken),
615617
sequel,
616618
builtQuery,
617619
(result) => prepared.mapResult(result, true),

drizzle-orm/src/pg-core/query-builders/count.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export class PgCountBuilder<
77
TSession extends PgSession<any, any, any>,
88
> extends SQL<number> implements Promise<number>, SQLWrapper {
99
private sql: SQL<number>;
10+
private token?: string;
1011

1112
static override readonly [entityKind] = 'PgCountBuilder';
1213
[Symbol.toStringTag] = 'PgCountBuilder';
@@ -46,19 +47,24 @@ export class PgCountBuilder<
4647
);
4748
}
4849

50+
/** @intrnal */
51+
setToken(token: string) {
52+
this.token = token;
53+
}
54+
4955
then<TResult1 = number, TResult2 = never>(
5056
onfulfilled?: ((value: number) => TResult1 | PromiseLike<TResult1>) | null | undefined,
5157
onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined,
5258
): Promise<TResult1 | TResult2> {
53-
return Promise.resolve(this.session.count(this.sql))
59+
return Promise.resolve(this.session.count(this.sql, this.token))
5460
.then(
5561
onfulfilled,
5662
onrejected,
5763
);
5864
}
5965

6066
catch(
61-
onRejected?: ((reason: any) => never | PromiseLike<never>) | null | undefined,
67+
onRejected?: ((reason: any) => any) | null | undefined,
6268
): Promise<number> {
6369
return this.then(undefined, onRejected);
6470
}

drizzle-orm/src/pg-core/query-builders/delete.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,16 @@ export class PgDeleteBase<
232232
return this._prepare(name);
233233
}
234234

235+
private authToken?: string;
236+
/** @internal */
237+
setToken(token: string) {
238+
this.authToken = token;
239+
return this;
240+
}
241+
235242
override execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
236243
return tracer.startActiveSpan('drizzle.operation', () => {
237-
return this._prepare().execute(placeholderValues);
244+
return this._prepare().execute(placeholderValues, this.authToken);
238245
});
239246
};
240247

drizzle-orm/src/pg-core/query-builders/insert.ts

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,16 @@ export class PgInsertBuilder<
6363
private overridingSystemValue_?: boolean,
6464
) {}
6565

66-
overridingSystemValue(): Omit<PgInsertBuilder<TTable, TQueryResult, true>, 'overridingSystemValue'> {
67-
this.overridingSystemValue_ = true;
68-
return this as any;
66+
private authToken?: string;
67+
/** @internal */
68+
setToken(token: string) {
69+
this.authToken = token;
70+
return this;
6971
}
7072

71-
values(value: PgInsertValue<TTable, OverrideT>): PgInsertBase<TTable, TQueryResult>;
72-
values(values: PgInsertValue<TTable, OverrideT>[]): PgInsertBase<TTable, TQueryResult>;
73-
values(
74-
values: PgInsertValue<TTable, OverrideT> | PgInsertValue<TTable, OverrideT>[],
75-
): PgInsertBase<TTable, TQueryResult> {
73+
values(value: PgInsertValue<TTable>): PgInsertBase<TTable, TQueryResult>;
74+
values(values: PgInsertValue<TTable>[]): PgInsertBase<TTable, TQueryResult>;
75+
values(values: PgInsertValue<TTable> | PgInsertValue<TTable>[]): PgInsertBase<TTable, TQueryResult> {
7676
values = Array.isArray(values) ? values : [values];
7777
if (values.length === 0) {
7878
throw new Error('values() must be called with at least one value');
@@ -87,15 +87,25 @@ export class PgInsertBuilder<
8787
return result;
8888
});
8989

90-
return new PgInsertBase(
91-
this.table,
92-
mappedValues,
93-
this.session,
94-
this.dialect,
95-
this.withList,
96-
false,
97-
this.overridingSystemValue_,
98-
);
90+
return this.authToken === undefined
91+
? new PgInsertBase(
92+
this.table,
93+
mappedValues,
94+
this.session,
95+
this.dialect,
96+
this.withList,
97+
false,
98+
this.overridingSystemValue_,
99+
)
100+
: new PgInsertBase(
101+
this.table,
102+
mappedValues,
103+
this.session,
104+
this.dialect,
105+
this.withList,
106+
false,
107+
this.overridingSystemValue_,
108+
).setToken(this.authToken) as any;
99109
}
100110

101111
select(selectQuery: (qb: QueryBuilder) => PgInsertSelectQueryBuilder<TTable>): PgInsertBase<TTable, TQueryResult>;
@@ -385,9 +395,16 @@ export class PgInsertBase<
385395
return this._prepare(name);
386396
}
387397

398+
private authToken?: string;
399+
/** @internal */
400+
setToken(token: string) {
401+
this.authToken = token;
402+
return this;
403+
}
404+
388405
override execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
389406
return tracer.startActiveSpan('drizzle.operation', () => {
390-
return this._prepare().execute(placeholderValues);
407+
return this._prepare().execute(placeholderValues, this.authToken);
391408
});
392409
};
393410

drizzle-orm/src/pg-core/query-builders/query.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,16 @@ export class PgRelationalQuery<TResult> extends QueryPromise<TResult>
142142
return this._toSQL().builtQuery;
143143
}
144144

145+
private authToken?: string;
146+
/** @internal */
147+
setToken(token: string) {
148+
this.authToken = token;
149+
return this;
150+
}
151+
145152
override execute(): Promise<TResult> {
146153
return tracer.startActiveSpan('drizzle.operation', () => {
147-
return this._prepare().execute();
154+
return this._prepare().execute(undefined, this.authToken);
148155
});
149156
}
150157
}

drizzle-orm/src/pg-core/query-builders/refresh-materialized-view.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,16 @@ export class PgRefreshMaterializedView<TQueryResult extends PgQueryResultHKT>
9292
return this._prepare(name);
9393
}
9494

95+
private authToken?: string;
96+
/** @internal */
97+
setToken(token: string) {
98+
this.authToken = token;
99+
return this;
100+
}
101+
95102
execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
96103
return tracer.startActiveSpan('drizzle.operation', () => {
97-
return this._prepare().execute(placeholderValues);
104+
return this._prepare().execute(placeholderValues, this.authToken);
98105
});
99106
};
100107
}

0 commit comments

Comments
 (0)