@@ -169,18 +169,30 @@ export class PgDialect {
169169 } ) ) ;
170170 }
171171
172- buildUpdateQuery ( { table, set, where, returning, withList } : PgUpdateConfig ) : SQL {
172+ buildUpdateQuery ( { table, set, where, returning, withList, from , joins } : PgUpdateConfig ) : SQL {
173173 const withSql = this . buildWithCTE ( withList ) ;
174174
175+ const tableName = table [ PgTable . Symbol . Name ] ;
176+ const tableSchema = table [ PgTable . Symbol . Schema ] ;
177+ const origTableName = table [ PgTable . Symbol . OriginalName ] ;
178+ const alias = tableName === origTableName ? undefined : tableName ;
179+ const tableSql = sql `${ tableSchema ? sql `${ sql . identifier ( tableSchema ) } .` : undefined } ${
180+ sql . identifier ( origTableName )
181+ } ${ alias && sql ` ${ sql . identifier ( alias ) } ` } `;
182+
175183 const setSql = this . buildUpdateSet ( table , set ) ;
176184
185+ const fromSql = from && sql . join ( [ sql . raw ( ' from ' ) , this . buildFromTable ( from ) ] ) ;
186+
187+ const joinsSql = this . buildJoins ( joins ) ;
188+
177189 const returningSql = returning
178- ? sql ` returning ${ this . buildSelection ( returning , { isSingleTable : true } ) } `
190+ ? sql ` returning ${ this . buildSelection ( returning , { isSingleTable : ! from } ) } `
179191 : undefined ;
180192
181193 const whereSql = where ? sql ` where ${ where } ` : undefined ;
182194
183- return sql `${ withSql } update ${ table } set ${ setSql } ${ whereSql } ${ returningSql } ` ;
195+ return sql `${ withSql } update ${ tableSql } set ${ setSql } ${ fromSql } ${ joinsSql } ${ whereSql } ${ returningSql } ` ;
184196 }
185197
186198 /**
@@ -245,6 +257,67 @@ export class PgDialect {
245257 return sql . join ( chunks ) ;
246258 }
247259
260+ private buildJoins ( joins : PgSelectJoinConfig [ ] | undefined ) : SQL | undefined {
261+ if ( ! joins || joins . length === 0 ) {
262+ return undefined ;
263+ }
264+
265+ const joinsArray : SQL [ ] = [ ] ;
266+
267+ for ( const [ index , joinMeta ] of joins . entries ( ) ) {
268+ if ( index === 0 ) {
269+ joinsArray . push ( sql ` ` ) ;
270+ }
271+ const table = joinMeta . table ;
272+ const lateralSql = joinMeta . lateral ? sql ` lateral` : undefined ;
273+
274+ if ( is ( table , PgTable ) ) {
275+ const tableName = table [ PgTable . Symbol . Name ] ;
276+ const tableSchema = table [ PgTable . Symbol . Schema ] ;
277+ const origTableName = table [ PgTable . Symbol . OriginalName ] ;
278+ const alias = tableName === origTableName ? undefined : joinMeta . alias ;
279+ joinsArray . push (
280+ sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${
281+ tableSchema ? sql `${ sql . identifier ( tableSchema ) } .` : undefined
282+ } ${ sql . identifier ( origTableName ) } ${ alias && sql ` ${ sql . identifier ( alias ) } ` } on ${ joinMeta . on } `,
283+ ) ;
284+ } else if ( is ( table , View ) ) {
285+ const viewName = table [ ViewBaseConfig ] . name ;
286+ const viewSchema = table [ ViewBaseConfig ] . schema ;
287+ const origViewName = table [ ViewBaseConfig ] . originalName ;
288+ const alias = viewName === origViewName ? undefined : joinMeta . alias ;
289+ joinsArray . push (
290+ sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${
291+ viewSchema ? sql `${ sql . identifier ( viewSchema ) } .` : undefined
292+ } ${ sql . identifier ( origViewName ) } ${ alias && sql ` ${ sql . identifier ( alias ) } ` } on ${ joinMeta . on } `,
293+ ) ;
294+ } else {
295+ joinsArray . push (
296+ sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${ table } on ${ joinMeta . on } ` ,
297+ ) ;
298+ }
299+ if ( index < joins . length - 1 ) {
300+ joinsArray . push ( sql ` ` ) ;
301+ }
302+ }
303+
304+ return sql . join ( joinsArray ) ;
305+ }
306+
307+ private buildFromTable (
308+ table : SQL | Subquery | PgViewBase | PgTable | undefined ,
309+ ) : SQL | Subquery | PgViewBase | PgTable | undefined {
310+ if ( is ( table , Table ) && table [ Table . Symbol . OriginalName ] !== table [ Table . Symbol . Name ] ) {
311+ let fullName = sql `${ sql . identifier ( table [ Table . Symbol . OriginalName ] ) } ` ;
312+ if ( table [ Table . Symbol . Schema ] ) {
313+ fullName = sql `${ sql . identifier ( table [ Table . Symbol . Schema ] ! ) } .${ fullName } ` ;
314+ }
315+ return sql `${ fullName } ${ sql . identifier ( table [ Table . Symbol . Name ] ) } ` ;
316+ }
317+
318+ return table ;
319+ }
320+
248321 buildSelectQuery (
249322 {
250323 withList,
@@ -300,60 +373,9 @@ export class PgDialect {
300373
301374 const selection = this . buildSelection ( fieldsList , { isSingleTable } ) ;
302375
303- const tableSql = ( ( ) => {
304- if ( is ( table , Table ) && table [ Table . Symbol . OriginalName ] !== table [ Table . Symbol . Name ] ) {
305- let fullName = sql `${ sql . identifier ( table [ Table . Symbol . OriginalName ] ) } ` ;
306- if ( table [ Table . Symbol . Schema ] ) {
307- fullName = sql `${ sql . identifier ( table [ Table . Symbol . Schema ] ! ) } .${ fullName } ` ;
308- }
309- return sql `${ fullName } ${ sql . identifier ( table [ Table . Symbol . Name ] ) } ` ;
310- }
311-
312- return table ;
313- } ) ( ) ;
314-
315- const joinsArray : SQL [ ] = [ ] ;
316-
317- if ( joins ) {
318- for ( const [ index , joinMeta ] of joins . entries ( ) ) {
319- if ( index === 0 ) {
320- joinsArray . push ( sql ` ` ) ;
321- }
322- const table = joinMeta . table ;
323- const lateralSql = joinMeta . lateral ? sql ` lateral` : undefined ;
324-
325- if ( is ( table , PgTable ) ) {
326- const tableName = table [ PgTable . Symbol . Name ] ;
327- const tableSchema = table [ PgTable . Symbol . Schema ] ;
328- const origTableName = table [ PgTable . Symbol . OriginalName ] ;
329- const alias = tableName === origTableName ? undefined : joinMeta . alias ;
330- joinsArray . push (
331- sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${
332- tableSchema ? sql `${ sql . identifier ( tableSchema ) } .` : undefined
333- } ${ sql . identifier ( origTableName ) } ${ alias && sql ` ${ sql . identifier ( alias ) } ` } on ${ joinMeta . on } `,
334- ) ;
335- } else if ( is ( table , View ) ) {
336- const viewName = table [ ViewBaseConfig ] . name ;
337- const viewSchema = table [ ViewBaseConfig ] . schema ;
338- const origViewName = table [ ViewBaseConfig ] . originalName ;
339- const alias = viewName === origViewName ? undefined : joinMeta . alias ;
340- joinsArray . push (
341- sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${
342- viewSchema ? sql `${ sql . identifier ( viewSchema ) } .` : undefined
343- } ${ sql . identifier ( origViewName ) } ${ alias && sql ` ${ sql . identifier ( alias ) } ` } on ${ joinMeta . on } `,
344- ) ;
345- } else {
346- joinsArray . push (
347- sql `${ sql . raw ( joinMeta . joinType ) } join${ lateralSql } ${ table } on ${ joinMeta . on } ` ,
348- ) ;
349- }
350- if ( index < joins . length - 1 ) {
351- joinsArray . push ( sql ` ` ) ;
352- }
353- }
354- }
376+ const tableSql = this . buildFromTable ( table ) ;
355377
356- const joinsSql = sql . join ( joinsArray ) ;
378+ const joinsSql = this . buildJoins ( joins ) ;
357379
358380 const whereSql = where ? sql ` where ${ where } ` : undefined ;
359381
0 commit comments