Skip to content

Commit d7e3535

Browse files
authored
Support update ... from (Pg | SQLite) (#2963)
* Implement `update ... from` in PG * Add `update ... from` in SQLite * Lint and format * Fix type error * Fix SQLite type errors * Lint and format * Push merge changes
1 parent c31614a commit d7e3535

File tree

13 files changed

+1152
-137
lines changed

13 files changed

+1152
-137
lines changed

drizzle-orm/src/pg-core/dialect.ts

Lines changed: 78 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,30 @@ export class PgDialect {
169169
}));
170170
}
171171

172-
buildUpdateQuery({ table, set, where, returning, withList }: PgUpdateConfig): SQL {
172+
buildUpdateQuery({ table, set, where, returning, withList, from, joins }: PgUpdateConfig): SQL {
173173
const withSql = this.buildWithCTE(withList);
174174

175+
const tableName = table[PgTable.Symbol.Name];
176+
const tableSchema = table[PgTable.Symbol.Schema];
177+
const origTableName = table[PgTable.Symbol.OriginalName];
178+
const alias = tableName === origTableName ? undefined : tableName;
179+
const tableSql = sql`${tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined}${
180+
sql.identifier(origTableName)
181+
}${alias && sql` ${sql.identifier(alias)}`}`;
182+
175183
const setSql = this.buildUpdateSet(table, set);
176184

185+
const fromSql = from && sql.join([sql.raw(' from '), this.buildFromTable(from)]);
186+
187+
const joinsSql = this.buildJoins(joins);
188+
177189
const returningSql = returning
178-
? sql` returning ${this.buildSelection(returning, { isSingleTable: true })}`
190+
? sql` returning ${this.buildSelection(returning, { isSingleTable: !from })}`
179191
: undefined;
180192

181193
const whereSql = where ? sql` where ${where}` : undefined;
182194

183-
return sql`${withSql}update ${table} set ${setSql}${whereSql}${returningSql}`;
195+
return sql`${withSql}update ${tableSql} set ${setSql}${fromSql}${joinsSql}${whereSql}${returningSql}`;
184196
}
185197

186198
/**
@@ -245,6 +257,67 @@ export class PgDialect {
245257
return sql.join(chunks);
246258
}
247259

260+
private buildJoins(joins: PgSelectJoinConfig[] | undefined): SQL | undefined {
261+
if (!joins || joins.length === 0) {
262+
return undefined;
263+
}
264+
265+
const joinsArray: SQL[] = [];
266+
267+
for (const [index, joinMeta] of joins.entries()) {
268+
if (index === 0) {
269+
joinsArray.push(sql` `);
270+
}
271+
const table = joinMeta.table;
272+
const lateralSql = joinMeta.lateral ? sql` lateral` : undefined;
273+
274+
if (is(table, PgTable)) {
275+
const tableName = table[PgTable.Symbol.Name];
276+
const tableSchema = table[PgTable.Symbol.Schema];
277+
const origTableName = table[PgTable.Symbol.OriginalName];
278+
const alias = tableName === origTableName ? undefined : joinMeta.alias;
279+
joinsArray.push(
280+
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
281+
tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined
282+
}${sql.identifier(origTableName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
283+
);
284+
} else if (is(table, View)) {
285+
const viewName = table[ViewBaseConfig].name;
286+
const viewSchema = table[ViewBaseConfig].schema;
287+
const origViewName = table[ViewBaseConfig].originalName;
288+
const alias = viewName === origViewName ? undefined : joinMeta.alias;
289+
joinsArray.push(
290+
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
291+
viewSchema ? sql`${sql.identifier(viewSchema)}.` : undefined
292+
}${sql.identifier(origViewName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
293+
);
294+
} else {
295+
joinsArray.push(
296+
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${table} on ${joinMeta.on}`,
297+
);
298+
}
299+
if (index < joins.length - 1) {
300+
joinsArray.push(sql` `);
301+
}
302+
}
303+
304+
return sql.join(joinsArray);
305+
}
306+
307+
private buildFromTable(
308+
table: SQL | Subquery | PgViewBase | PgTable | undefined,
309+
): SQL | Subquery | PgViewBase | PgTable | undefined {
310+
if (is(table, Table) && table[Table.Symbol.OriginalName] !== table[Table.Symbol.Name]) {
311+
let fullName = sql`${sql.identifier(table[Table.Symbol.OriginalName])}`;
312+
if (table[Table.Symbol.Schema]) {
313+
fullName = sql`${sql.identifier(table[Table.Symbol.Schema]!)}.${fullName}`;
314+
}
315+
return sql`${fullName} ${sql.identifier(table[Table.Symbol.Name])}`;
316+
}
317+
318+
return table;
319+
}
320+
248321
buildSelectQuery(
249322
{
250323
withList,
@@ -300,60 +373,9 @@ export class PgDialect {
300373

301374
const selection = this.buildSelection(fieldsList, { isSingleTable });
302375

303-
const tableSql = (() => {
304-
if (is(table, Table) && table[Table.Symbol.OriginalName] !== table[Table.Symbol.Name]) {
305-
let fullName = sql`${sql.identifier(table[Table.Symbol.OriginalName])}`;
306-
if (table[Table.Symbol.Schema]) {
307-
fullName = sql`${sql.identifier(table[Table.Symbol.Schema]!)}.${fullName}`;
308-
}
309-
return sql`${fullName} ${sql.identifier(table[Table.Symbol.Name])}`;
310-
}
311-
312-
return table;
313-
})();
314-
315-
const joinsArray: SQL[] = [];
316-
317-
if (joins) {
318-
for (const [index, joinMeta] of joins.entries()) {
319-
if (index === 0) {
320-
joinsArray.push(sql` `);
321-
}
322-
const table = joinMeta.table;
323-
const lateralSql = joinMeta.lateral ? sql` lateral` : undefined;
324-
325-
if (is(table, PgTable)) {
326-
const tableName = table[PgTable.Symbol.Name];
327-
const tableSchema = table[PgTable.Symbol.Schema];
328-
const origTableName = table[PgTable.Symbol.OriginalName];
329-
const alias = tableName === origTableName ? undefined : joinMeta.alias;
330-
joinsArray.push(
331-
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
332-
tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined
333-
}${sql.identifier(origTableName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
334-
);
335-
} else if (is(table, View)) {
336-
const viewName = table[ViewBaseConfig].name;
337-
const viewSchema = table[ViewBaseConfig].schema;
338-
const origViewName = table[ViewBaseConfig].originalName;
339-
const alias = viewName === origViewName ? undefined : joinMeta.alias;
340-
joinsArray.push(
341-
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
342-
viewSchema ? sql`${sql.identifier(viewSchema)}.` : undefined
343-
}${sql.identifier(origViewName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
344-
);
345-
} else {
346-
joinsArray.push(
347-
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${table} on ${joinMeta.on}`,
348-
);
349-
}
350-
if (index < joins.length - 1) {
351-
joinsArray.push(sql` `);
352-
}
353-
}
354-
}
376+
const tableSql = this.buildFromTable(table);
355377

356-
const joinsSql = sql.join(joinsArray);
378+
const joinsSql = this.buildJoins(joins);
357379

358380
const whereSql = where ? sql` where ${where}` : undefined;
359381

drizzle-orm/src/pg-core/query-builders/select.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ import type {
3434
LockConfig,
3535
LockStrength,
3636
PgCreateSetOperatorFn,
37-
PgJoinFn,
3837
PgSelectConfig,
3938
PgSelectDynamic,
4039
PgSelectHKT,
4140
PgSelectHKTBase,
41+
PgSelectJoinFn,
4242
PgSelectPrepare,
4343
PgSelectWithout,
4444
PgSetOperatorExcludedMethods,
@@ -194,7 +194,7 @@ export abstract class PgSelectQueryBuilderBase<
194194

195195
private createJoin<TJoinType extends JoinType>(
196196
joinType: TJoinType,
197-
): PgJoinFn<this, TDynamic, TJoinType> {
197+
): PgSelectJoinFn<this, TDynamic, TJoinType> {
198198
return (
199199
table: PgTable | Subquery | PgViewBase | SQL,
200200
on: ((aliases: TSelection) => SQL | undefined) | SQL | undefined,

drizzle-orm/src/pg-core/query-builders/select.types.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ export interface PgSelectConfig {
7979
}[];
8080
}
8181

82-
export type PgJoin<
82+
export type PgSelectJoin<
8383
T extends AnyPgSelectQueryBuilder,
8484
TDynamic extends boolean,
8585
TJoinType extends JoinType,
@@ -108,7 +108,7 @@ export type PgJoin<
108108
>
109109
: never;
110110

111-
export type PgJoinFn<
111+
export type PgSelectJoinFn<
112112
T extends AnyPgSelectQueryBuilder,
113113
TDynamic extends boolean,
114114
TJoinType extends JoinType,
@@ -118,7 +118,7 @@ export type PgJoinFn<
118118
>(
119119
table: TJoinedTable,
120120
on: ((aliases: T['_']['selection']) => SQL | undefined) | SQL | undefined,
121-
) => PgJoin<T, TDynamic, TJoinType, TJoinedTable, TJoinedName>;
121+
) => PgSelectJoin<T, TDynamic, TJoinType, TJoinedTable, TJoinedName>;
122122

123123
export type SelectedFieldsFlat = SelectedFieldsFlatBase<PgColumn>;
124124

0 commit comments

Comments
 (0)