Skip to content

Commit 76d8d4d

Browse files
authored
Propagate context in case of an alias also (#93)
* add wraps * add test * work around alias * changeset
1 parent 61d637c commit 76d8d4d

File tree

6 files changed

+74
-10
lines changed

6 files changed

+74
-10
lines changed

.changeset/tiny-ties-sell.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'druid-query-toolkit': patch
3+
---
4+
5+
Propogate context in case of an alias also

src/sql/sql-case/sql-case.spec.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,32 @@ describe('CaseExpression', () => {
2626
backAndForth(sql, SqlCase);
2727
});
2828

29+
it('.ifThenElse', () => {
30+
// Test with condition, then, and else
31+
const condition = SqlExpression.parse('x > 5');
32+
const thenExpr = SqlExpression.parse('"result is true"');
33+
const elseExpr = SqlExpression.parse('"result is false"');
34+
35+
expect(SqlCase.ifThenElse(condition, thenExpr, elseExpr).toString()).toEqual(
36+
'CASE WHEN x > 5 THEN "result is true" ELSE "result is false" END',
37+
);
38+
39+
// Test with literal values
40+
expect(SqlCase.ifThenElse(condition, 'yes', 'no').toString()).toEqual(
41+
`CASE WHEN x > 5 THEN 'yes' ELSE 'no' END`,
42+
);
43+
44+
// Test numeric literals
45+
expect(SqlCase.ifThenElse(condition, 1, 0).toString()).toEqual(
46+
'CASE WHEN x > 5 THEN 1 ELSE 0 END',
47+
);
48+
49+
// Test without else expression
50+
expect(SqlCase.ifThenElse(condition, thenExpr).toString()).toEqual(
51+
'CASE WHEN x > 5 THEN "result is true" END',
52+
);
53+
});
54+
2955
it('caseless CASE Expression', () => {
3056
const sql = `CASE WHEN B THEN C END`;
3157

src/sql/sql-case/sql-case.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import type { SqlBaseValue, SqlTypeDesignator, Substitutor } from '../sql-base';
1616
import { SqlBase } from '../sql-base';
1717
import { SqlExpression } from '../sql-expression';
18+
import type { LiteralValue } from '../sql-literal/sql-literal';
1819
import { SeparatedArray, SPACE } from '../utils';
1920

2021
import { SqlWhenThenPart } from './sql-when-then-part';
@@ -34,14 +35,15 @@ export class SqlCase extends SqlExpression {
3435

3536
static ifThenElse(
3637
conditionExpression: SqlExpression,
37-
thenExpression: SqlExpression,
38-
elseExpression?: SqlExpression,
38+
thenExpression: SqlExpression | LiteralValue,
39+
elseExpression?: SqlExpression | LiteralValue,
3940
) {
4041
return new SqlCase({
4142
whenThenParts: SeparatedArray.fromSingleValue(
4243
SqlWhenThenPart.create(conditionExpression, thenExpression),
4344
),
44-
elseExpression,
45+
elseExpression:
46+
typeof elseExpression !== 'undefined' ? SqlExpression.wrap(elseExpression) : undefined,
4547
});
4648
}
4749

src/sql/sql-case/sql-when-then-part.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import type { SqlBaseValue, SqlTypeDesignator, Substitutor } from '../sql-base';
1616
import { SqlBase } from '../sql-base';
1717
import { SqlExpression } from '../sql-expression';
18+
import type { LiteralValue } from '../sql-literal/sql-literal';
1819
import { SeparatedArray, Separator } from '../utils';
1920

2021
export interface SqlWhenThenPartValue extends SqlBaseValue {
@@ -30,13 +31,13 @@ export class SqlWhenThenPart extends SqlBase {
3031

3132
static create(
3233
whenExpressions: SeparatedArray<SqlExpression> | SqlExpression[] | SqlExpression,
33-
thenExpression: SqlExpression,
34+
thenExpression: SqlExpression | LiteralValue,
3435
): SqlWhenThenPart {
3536
return new SqlWhenThenPart({
3637
whenExpressions: SeparatedArray.fromArray(
3738
whenExpressions instanceof SqlExpression ? [whenExpressions] : whenExpressions,
3839
),
39-
thenExpression: SqlExpression.verify(thenExpression),
40+
thenExpression: SqlExpression.wrap(thenExpression),
4041
});
4142
}
4243

src/sql/sql-query/sql-query.spec.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ describe('SqlQuery', () => {
183183
`);
184184
});
185185

186-
it('works in advanced case', () => {
186+
it('works with context', () => {
187187
const query = SqlQuery.selectStarFrom(
188188
SqlQuery.selectStarFrom(SqlTable.create('lol')).changeContext({ a: 1 }),
189189
)
@@ -209,6 +209,33 @@ describe('SqlQuery', () => {
209209
WHERE channel = '#en.wikipedia'
210210
`);
211211
});
212+
213+
it('works with context and alias', () => {
214+
const query = SqlQuery.selectStarFrom(
215+
SqlQuery.selectStarFrom(SqlTable.create('lol')).changeContext({ a: 1 }).as('t'),
216+
)
217+
.changeSelectExpressions([
218+
SqlColumn.create('channel'),
219+
SqlColumn.create('page'),
220+
SqlColumn.create('user'),
221+
SqlColumn.create('as'),
222+
])
223+
.changeWhereExpression(SqlExpression.parse(`channel = '#en.wikipedia'`));
224+
225+
expect(String(query)).toEqual(sane`
226+
SET a = 1;
227+
SELECT
228+
"channel",
229+
"page",
230+
"user",
231+
"as"
232+
FROM (
233+
SELECT *
234+
FROM "lol"
235+
) AS "t"
236+
WHERE channel = '#en.wikipedia'
237+
`);
238+
});
212239
});
213240

214241
describe('.from', () => {

src/sql/sql-query/sql-query.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,13 @@ export class SqlQuery extends SqlExpression {
102102
static from(from: string | SqlExpression | SqlFromClause): SqlQuery {
103103
// Extract context from the inner query if given
104104
let contextStatements: SeparatedArray<SqlSetStatement> | undefined;
105-
if (from instanceof SqlQuery || from instanceof SqlWithQuery) {
106-
contextStatements = from.contextStatements;
107-
if (contextStatements) {
108-
from = from.changeContextStatements(undefined);
105+
if (from instanceof SqlExpression) {
106+
const underlyingFrom = from.getUnderlyingExpression();
107+
if (underlyingFrom instanceof SqlQuery || underlyingFrom instanceof SqlWithQuery) {
108+
contextStatements = underlyingFrom.contextStatements;
109+
if (contextStatements) {
110+
from = from.changeUnderlyingExpression(underlyingFrom.changeContextStatements(undefined));
111+
}
109112
}
110113
}
111114

0 commit comments

Comments
 (0)