diff --git a/packages/cubejs-api-gateway/package.json b/packages/cubejs-api-gateway/package.json index 467bf436eec34..b39b0c2aa1556 100644 --- a/packages/cubejs-api-gateway/package.json +++ b/packages/cubejs-api-gateway/package.json @@ -30,6 +30,7 @@ "@cubejs-backend/native": "1.2.30", "@cubejs-backend/shared": "1.2.30", "@ungap/structured-clone": "^0.3.4", + "assert-never": "^1.4.0", "body-parser": "^1.19.0", "chrono-node": "^2.6.2", "express": "^4.21.1", diff --git a/packages/cubejs-api-gateway/src/gateway.ts b/packages/cubejs-api-gateway/src/gateway.ts index 28972b895a9fe..1dedb76efd759 100644 --- a/packages/cubejs-api-gateway/src/gateway.ts +++ b/packages/cubejs-api-gateway/src/gateway.ts @@ -1,5 +1,6 @@ /* eslint-disable no-restricted-syntax */ import * as stream from 'stream'; +import { assertNever } from 'assert-never'; import jwt, { Algorithm as JWTAlgorithm } from 'jsonwebtoken'; import R from 'ramda'; import bodyParser from 'body-parser'; @@ -83,6 +84,7 @@ import { normalizeQueryCancelPreAggregations, normalizeQueryPreAggregationPreview, normalizeQueryPreAggregations, + parseInputMemberExpression, preAggsJobsRequestSchema, remapToQueryAdapterFormat, } from './query'; @@ -1392,30 +1394,46 @@ class ApiGateway { } private parseMemberExpression(memberExpression: string): string | ParsedMemberExpression { - try { - if (memberExpression.startsWith('{')) { - const obj = JSON.parse(memberExpression); - const args = obj.cube_params; - args.push(`return \`${obj.expr}\``); + if (memberExpression.startsWith('{')) { + const obj = parseInputMemberExpression(JSON.parse(memberExpression)); + let expression: ParsedMemberExpression['expression']; + switch (obj.expr.type) { + case 'SqlFunction': + expression = [ + ...obj.expr.cubeParams, + `return \`${obj.expr.sql}\``, + ]; + break; + case 'PatchMeasure': + expression = { + type: 'PatchMeasure', + sourceMeasure: obj.expr.sourceMeasure, + replaceAggregationType: obj.expr.replaceAggregationType, + addFilters: obj.expr.addFilters.map(filter => [ + ...filter.cubeParams, + `return \`${filter.sql}\``, + ]), + }; + break; + default: + assertNever(obj.expr); + } - const groupingSet = obj.grouping_set ? { - groupType: obj.grouping_set.group_type, - id: obj.grouping_set.id, - subId: obj.grouping_set.sub_id ? obj.grouping_set.sub_id : undefined - } : undefined; + const groupingSet = obj.groupingSet ? { + groupType: obj.groupingSet.groupType, + id: obj.groupingSet.id, + subId: obj.groupingSet.subId ? obj.groupingSet.subId : undefined + } : undefined; - return { - cubeName: obj.cube_name, - name: obj.alias, - expressionName: obj.alias, - expression: args, - definition: memberExpression, - groupingSet, - }; - } else { - return memberExpression; - } - } catch { + return { + cubeName: obj.cubeName, + name: obj.alias, + expressionName: obj.alias, + expression, + definition: memberExpression, + groupingSet, + }; + } else { return memberExpression; } } @@ -1433,14 +1451,31 @@ class ApiGateway { }; } - private evalMemberExpression(memberExpression: MemberExpression | ParsedMemberExpression): string | MemberExpression { - const expression = Array.isArray(memberExpression.expression) ? - Function.constructor.apply(null, memberExpression.expression) : memberExpression.expression; + private evalMemberExpression(memberExpression: MemberExpression | ParsedMemberExpression): MemberExpression | ParsedMemberExpression { + if (typeof memberExpression.expression === 'function') { + return memberExpression; + } - return { - ...memberExpression, - expression, - }; + if (Array.isArray(memberExpression.expression)) { + return { + ...memberExpression, + expression: Function.constructor.apply(null, memberExpression.expression), + }; + } + + if (memberExpression.expression.type === 'PatchMeasure') { + return { + ...memberExpression, + expression: { + ...memberExpression.expression, + addFilters: memberExpression.expression.addFilters.map(filter => ({ + sql: Function.constructor.apply(null, filter), + })), + } + }; + } + + throw new Error(`Unexpected member expression to evaluate: ${memberExpression}`); } public async sqlGenerators({ context, res }: { context: RequestContext, res: ResponseResultFn }) { diff --git a/packages/cubejs-api-gateway/src/query.js b/packages/cubejs-api-gateway/src/query.js index a6936f43dcf2c..147aee1b98020 100644 --- a/packages/cubejs-api-gateway/src/query.js +++ b/packages/cubejs-api-gateway/src/query.js @@ -39,11 +39,31 @@ const getPivotQuery = (queryType, queries) => { return pivotQuery; }; +const parsedPatchMeasureFilterExpression = Joi.array().items(Joi.string()); + +const evaluatedPatchMeasureFilterExpression = Joi.object().keys({ + sql: Joi.func().required(), +}); + +const parsedPatchMeasureExpression = Joi.object().keys({ + type: Joi.valid('PatchMeasure').required(), + sourceMeasure: Joi.string().required(), + replaceAggregationType: Joi.string().allow(null).required(), + addFilters: Joi.array().items(parsedPatchMeasureFilterExpression).required(), +}); + +const evaluatedPatchMeasureExpression = parsedPatchMeasureExpression.keys({ + addFilters: Joi.array().items(evaluatedPatchMeasureFilterExpression).required(), +}); + const id = Joi.string().regex(/^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$/); const idOrMemberExpressionName = Joi.string().regex(/^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$|^[a-zA-Z0-9_]+$/); const dimensionWithTime = Joi.string().regex(/^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(\.[a-zA-Z0-9_]+)?$/); const parsedMemberExpression = Joi.object().keys({ - expression: Joi.array().items(Joi.string()).min(1).required(), + expression: Joi.alternatives( + Joi.array().items(Joi.string()).min(1), + parsedPatchMeasureExpression, + ).required(), cubeName: Joi.string().required(), name: Joi.string().required(), expressionName: Joi.string(), @@ -55,7 +75,43 @@ const parsedMemberExpression = Joi.object().keys({ }) }); const memberExpression = parsedMemberExpression.keys({ - expression: Joi.func().required(), + expression: Joi.alternatives( + Joi.func().required(), + evaluatedPatchMeasureExpression, + ).required(), +}); + +const inputSqlFunction = Joi.object().keys({ + cubeParams: Joi.array().items(Joi.string()).required(), + sql: Joi.string().required(), +}); + +// This should be aligned with cubesql side +const inputMemberExpressionSqlFunction = inputSqlFunction.keys({ + type: Joi.valid('SqlFunction').required(), +}); + +// This should be aligned with cubesql side +const inputMemberExpressionPatchMeasure = Joi.object().keys({ + type: Joi.valid('PatchMeasure').required(), + sourceMeasure: Joi.string().required(), + replaceAggregationType: Joi.string().allow(null).required(), + addFilters: Joi.array().items(inputSqlFunction).required(), +}); + +// This should be aligned with cubesql side +const inputMemberExpression = Joi.object().keys({ + cubeName: Joi.string().required(), + alias: Joi.string().required(), + expr: Joi.alternatives( + inputMemberExpressionSqlFunction, + inputMemberExpressionPatchMeasure, + ), + groupingSet: Joi.object().keys({ + groupType: Joi.valid('Rollup', 'Cube').required(), + id: Joi.number().required(), + subId: Joi.number().allow(null), + }).allow(null) }); const operators = [ @@ -215,6 +271,20 @@ const normalizeQueryFilters = (filter) => ( }) ); +/** + * Parse incoming member expression + * @param {unknown} expression + * @throws {import('./UserError').UserError} + * @returns {import('./types/query').InputMemberExpression} + */ +function parseInputMemberExpression(expression) { + const { error } = inputMemberExpression.validate(expression); + if (error) { + throw new UserError(`Invalid member expression format: ${error.message || error.toString()}`); + } + return expression; +} + /** * Normalize incoming network query. * @param {Query} query @@ -384,5 +454,6 @@ export { normalizeQueryPreAggregations, normalizeQueryPreAggregationPreview, normalizeQueryCancelPreAggregations, + parseInputMemberExpression, remapToQueryAdapterFormat, }; diff --git a/packages/cubejs-api-gateway/src/types/query.ts b/packages/cubejs-api-gateway/src/types/query.ts index 8c1f1c03c5cf7..b26b39c09a9db 100644 --- a/packages/cubejs-api-gateway/src/types/query.ts +++ b/packages/cubejs-api-gateway/src/types/query.ts @@ -38,14 +38,34 @@ type LogicalOrFilter = { or: (QueryFilter | LogicalAndFilter)[] }; +export type GroupingSetType = 'Rollup' | 'Cube'; + type GroupingSet = { - groupType: string, + groupType: GroupingSetType, id: number, subId?: null | number }; +export type EvalPatchMeasureFilterExpression = { + sql: Function, +}; + +export type PatchMeasureExpression = { + type: 'PatchMeasure', + sourceMeasure: string, + replaceAggregationType: string | null, + addFilters: Array>, +}; + +export type EvalPatchMeasureExpression = { + type: 'PatchMeasure', + sourceMeasure: string, + replaceAggregationType: string | null, + addFilters: Array, +}; + type ParsedMemberExpression = { - expression: string[]; + expression: string[] | PatchMeasureExpression; cubeName: string; name: string; expressionName: string; @@ -54,7 +74,33 @@ type ParsedMemberExpression = { }; type MemberExpression = Omit & { - expression: Function; + expression: Function | EvalPatchMeasureExpression; +}; + +type InputSqlFunction = { + cubeParams: Array, + sql: string, +}; + +export type InputMemberExpressionSqlFunction = { + type: 'SqlFunction' +} & InputSqlFunction; + +export type InputMemberExpressionPatchMeasure = { + type: 'PatchMeasure', + sourceMeasure: string, + replaceAggregationType: string | null, + addFilters: Array, +}; + +export type InputMemberExpressionExpr = InputMemberExpressionSqlFunction | InputMemberExpressionPatchMeasure; + +// This should be aligned with cubesql side +export type InputMemberExpression = { + cubeName: string, + alias: string, + expr: InputMemberExpressionExpr, + groupingSet: GroupingSet | null, }; /** diff --git a/packages/cubejs-api-gateway/test/index.test.ts b/packages/cubejs-api-gateway/test/index.test.ts index 2153442e814ad..2a8d5f03bab0f 100644 --- a/packages/cubejs-api-gateway/test/index.test.ts +++ b/packages/cubejs-api-gateway/test/index.test.ts @@ -771,6 +771,21 @@ describe('API Gateway', () => { }); describe('sql api member expressions evaluations', () => { + const query = { + measures: [ + // eslint-disable-next-line no-template-curly-in-string + '{"cubeName":"sales","alias":"sum_sales_line_i","expr":{"type":"SqlFunction","cubeParams":["sales"],"sql":"SUM(${sales.line_items_price})"},"groupingSet":null}' + ], + dimensions: [ + // eslint-disable-next-line no-template-curly-in-string + '{"cubeName":"sales","alias":"users_age","expr":{"type":"SqlFunction","cubeParams":["sales"],"sql":"${sales.users_age}"},"groupingSet":null}', + // eslint-disable-next-line no-template-curly-in-string + '{"cubeName":"sales","alias":"cast_sales_users","expr":{"type":"SqlFunction","cubeParams":["sales"],"sql":"CAST(${sales.users_first_name} AS TEXT)"},"groupingSet":null}' + ], + segments: [], + order: [] + }; + test('throw error if expressions are not allowed', async () => { const { apiGateway } = await createApiGateway(); const request: QueryRequest = { @@ -779,20 +794,7 @@ describe('API Gateway', () => { const errorMessage = message as { error: string }; expect(errorMessage.error).toEqual('Error: Expressions are not allowed in this context'); }, - query: { - measures: [ - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"sum_sales_line_i","cube_params":["sales"],"expr":"SUM(${sales.line_items_price})","grouping_set":null}' - ], - dimensions: [ - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"users_age","cube_params":["sales"],"expr":"${sales.users_age}","grouping_set":null}', - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"cast_sales_users","cube_params":["sales"],"expr":"CAST(${sales.users_first_name} AS TEXT)","grouping_set":null}' - ], - segments: [], - order: [] - }, + query, expressionParams: [], exportAnnotatedSql: true, memberExpressions: false, @@ -820,20 +822,7 @@ describe('API Gateway', () => { res(message) { expect(message.hasOwnProperty('sql')).toBe(true); }, - query: { - measures: [ - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"sum_sales_line_i","cube_params":["sales"],"expr":"SUM(${sales.line_items_price})","grouping_set":null}' - ], - dimensions: [ - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"users_age","cube_params":["sales"],"expr":"${sales.users_age}","grouping_set":null}', - // eslint-disable-next-line no-template-curly-in-string - '{"cube_name":"sales","alias":"cast_sales_users","cube_params":["sales"],"expr":"CAST(${sales.users_first_name} AS TEXT)","grouping_set":null}' - ], - segments: [], - order: [] - }, + query, expressionParams: [], exportAnnotatedSql: true, memberExpressions: true, diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts b/packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts index ea010dc0fd711..1e44283cdbca1 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts +++ b/packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts @@ -1,5 +1,6 @@ import { UserError } from '../compiler/UserError'; import type { BaseQuery } from './BaseQuery'; +import { MeasureDefinition } from "../compiler/CubeEvaluator"; export class BaseMeasure { public readonly expression: any; @@ -10,6 +11,100 @@ export class BaseMeasure { public readonly isMemberExpression: boolean = false; + protected readonly patchedMeasure: MeasureDefinition | null = null; + + protected preparePatchedMeasure(sourceMeasure: string, newMeasureType: string | null, addFilters: Array<{sql: Function}>): MeasureDefinition { + const source = this.query.cubeEvaluator.measureByPath(sourceMeasure); + + let resultMeasureType = source.type; + if (newMeasureType !== null) { + switch (source.type) { + case 'sum': + case 'avg': + case 'min': + case 'max': + switch (newMeasureType) { + case 'sum': + case 'avg': + case 'min': + case 'max': + case 'count_distinct': + case 'count_distinct_approx': + // Can change from avg/... to count_distinct + // Latter does not care what input value is + // ok, do nothing + break; + default: + throw new UserError( + `Unsupported measure type replacement for ${sourceMeasure}: ${source.type} => ${newMeasureType}` + ); + } + break; + case 'count_distinct': + case 'count_distinct_approx': + switch (newMeasureType) { + case 'count_distinct': + case 'count_distinct_approx': + // ok, do nothing + break; + default: + // Can not change from count_distinct to avg/... + // Latter do care what input value is, and original measure can be defined on strings + throw new UserError( + `Unsupported measure type replacement for ${sourceMeasure}: ${source.type} => ${newMeasureType}` + ); + } + break; + default: + // Can not change from string, time, boolean, number + // Aggregation is already included in SQL, it's hard to patch that + // Can not change from count + // There's no SQL at all + throw new UserError( + `Unsupported measure type replacement for ${sourceMeasure}: ${source.type} => ${newMeasureType}` + ); + } + + resultMeasureType = newMeasureType; + } + + const resultFilters = source.filters ?? []; + + if (addFilters.length > 0) { + switch (resultMeasureType) { + case 'sum': + case 'avg': + case 'min': + case 'max': + case 'count': + case 'count_distinct': + case 'count_distinct_approx': + // ok, do nothing + break; + default: + // Can not add filters to string, time, boolean, number + // Aggregation is already included in SQL, it's hard to patch that + throw new UserError( + `Unsupported additional filters for measure ${sourceMeasure} type ${source.type}` + ); + } + + resultFilters.push(...addFilters); + } + + const patchedFrom = this.query.cubeEvaluator.parsePath('measures', sourceMeasure); + + return { + ...source, + type: resultMeasureType, + filters: resultFilters, + patchedFrom: { + cubeName: patchedFrom[0], + name: patchedFrom[1], + }, + }; + } + public constructor( protected readonly query: BaseQuery, public readonly measure: any @@ -20,6 +115,14 @@ export class BaseMeasure { // In case of SQL push down expressionName doesn't contain cube name. It's just a column name. this.expressionName = measure.expressionName || `${measure.cubeName}.${measure.name}`; this.isMemberExpression = !!measure.definition; + + if (measure.expression.type === 'PatchMeasure') { + this.patchedMeasure = this.preparePatchedMeasure( + measure.expression.sourceMeasure, + measure.expression.replaceAggregationType, + measure.expression.addFilters, + ); + } } } @@ -74,10 +177,16 @@ export class BaseMeasure { } public measureDefinition() { + if (this.patchedMeasure) { + return this.patchedMeasure; + } return this.query.cubeEvaluator.measureByPath(this.measure); } public definition(): any { + if (this.patchedMeasure) { + return this.patchedMeasure; + } if (this.expression) { return { sql: this.expression, diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 8ec91a51fb1a3..1a3db657ba251 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -2166,9 +2166,14 @@ export class BaseQuery { } traverseSymbol(s) { - return s.path() ? - [s.cube().name].concat(this.evaluateSymbolSql(s.path()[0], s.path()[1], s.definition())) : - this.evaluateSql(s.cube().name, s.definition().sql); + // TODO why not just evaluateSymbolSql for every branch? + if (s.path()) { + return [s.cube().name].concat(this.evaluateSymbolSql(s.path()[0], s.path()[1], s.definition())); + } else if (s.patchedMeasure?.patchedFrom) { + return [s.patchedMeasure.patchedFrom.cubeName].concat(this.evaluateSymbolSql(s.patchedMeasure.patchedFrom.cubeName, s.patchedMeasure.patchedFrom.name, s.definition())); + } else { + return this.evaluateSql(s.cube().name, s.definition().sql); + } } collectCubeNames(excludeTimeDimensions) { @@ -2519,6 +2524,9 @@ export class BaseQuery { if (!memberExpressionType) { this.pushMemberNameForCollectionIfNecessary(cubeName, name); } + if (symbol.patchedFrom) { + this.pushMemberNameForCollectionIfNecessary(symbol.patchedFrom.cubeName, symbol.patchedFrom.name); + } const memberPathArray = [cubeName, name]; // Member path needs to be expanded to granularity if subPropertyName is provided. // Without this: infinite recursion with maximum call stack size exceeded. diff --git a/packages/cubejs-schema-compiler/src/compiler/CubeEvaluator.ts b/packages/cubejs-schema-compiler/src/compiler/CubeEvaluator.ts index c97d922f1c514..6b7c32f1243a0 100644 --- a/packages/cubejs-schema-compiler/src/compiler/CubeEvaluator.ts +++ b/packages/cubejs-schema-compiler/src/compiler/CubeEvaluator.ts @@ -56,6 +56,7 @@ export type MeasureDefinition = { reduceByReferences?: string[], addGroupByReferences?: string[], timeShiftReferences?: TimeShiftDefinitionReference[], + patchedFrom?: { cubeName: string, name: string }, }; export type PreAggregationFilters = { diff --git a/packages/cubejs-schema-compiler/test/integration/postgres/sql-generation.test.ts b/packages/cubejs-schema-compiler/test/integration/postgres/sql-generation.test.ts index 1a12ba758b5cb..4db15a7e3a359 100644 --- a/packages/cubejs-schema-compiler/test/integration/postgres/sql-generation.test.ts +++ b/packages/cubejs-schema-compiler/test/integration/postgres/sql-generation.test.ts @@ -3753,4 +3753,48 @@ SELECT 1 AS revenue, cast('2024-01-01' AS timestamp) as time UNION ALL ungrouped_measure_with_filter__view__count: 1, ungrouped_measure_with_filter__view__sum_filter: 1 }])); + + it('patched measure expression', async () => { + await runQueryTest( + { + measures: [ + 'visitors.revenue', + 'visitors.visitor_revenue', + { + expression: { + type: 'PatchMeasure', + sourceMeasure: 'visitors.revenue', + replaceAggregationType: 'max', + addFilters: [], + }, + cubeName: 'visitors', + name: 'max_revenue', + definition: 'PatchMeasure(visitors.revenue, max, [])', + }, + { + expression: { + type: 'PatchMeasure', + sourceMeasure: 'visitors.revenue', + replaceAggregationType: null, + addFilters: [ + { + sql: (visitors) => `${visitors.source} IN ('google', 'some')`, + }, + ], + }, + cubeName: 'visitors', + name: 'google_revenue', + // eslint-disable-next-line no-template-curly-in-string + definition: 'PatchMeasure(visitors.revenue, min, [${visitors.source} IN (\'google\', \'some\')])', + }, + ], + }, + [{ + visitors__revenue: '2000', + visitors__visitor_revenue: '300', + visitors__max_revenue: 500, + visitors__google_revenue: '600', + }] + ); + }); }); diff --git a/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap b/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap index 0c93be9f71419..54ebaadaa985b 100644 --- a/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap +++ b/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap @@ -546,6 +546,40 @@ Array [ ] `; +exports[`SQL API Postgres (Data) measure with ad-hoc filter and original measure: measure-with-ad-hoc-filters-and-original-measure 1`] = ` +Array [ + Object { + "new_amount": 300, + "total_amount": 1700, + }, +] +`; + +exports[`SQL API Postgres (Data) measure with ad-hoc filter: measure-with-ad-hoc-filters 1`] = ` +Array [ + Object { + "new_amount": 300, + }, +] +`; + +exports[`SQL API Postgres (Data) measure with replaced aggregation and original measure: measure-with-replaced-aggregation-and-original-measure 1`] = ` +Array [ + Object { + "min_amount": 100, + "sum_amount": 1700, + }, +] +`; + +exports[`SQL API Postgres (Data) measure with replaced aggregation: measure-with-replaced-aggregation 1`] = ` +Array [ + Object { + "min_amount": 100, + }, +] +`; + exports[`SQL API Postgres (Data) metabase max number: metabase max number 1`] = ` Array [ Object { diff --git a/packages/cubejs-testing/test/smoke-cubesql.test.ts b/packages/cubejs-testing/test/smoke-cubesql.test.ts index e8d71956921e4..bdc91523fe2f4 100644 --- a/packages/cubejs-testing/test/smoke-cubesql.test.ts +++ b/packages/cubejs-testing/test/smoke-cubesql.test.ts @@ -796,5 +796,67 @@ filter_subq AS ( const res = await connection.query(query); expect(res.rows).toMatchSnapshot('wrapper-duplicated-members'); }); + + test('measure with replaced aggregation', async () => { + const query = ` + SELECT + MIN(totalAmount) AS min_amount + FROM + Orders + `; + + const res = await connection.query(query); + expect(res.rows).toMatchSnapshot('measure-with-replaced-aggregation'); + }); + + test('measure with replaced aggregation and original measure', async () => { + const query = ` + SELECT + SUM(totalAmount) AS sum_amount, + MIN(totalAmount) AS min_amount + FROM + Orders + `; + + const res = await connection.query(query); + expect(res.rows).toMatchSnapshot('measure-with-replaced-aggregation-and-original-measure'); + }); + + test('measure with ad-hoc filter', async () => { + const query = ` + SELECT + SUM( + CASE status = 'new' + WHEN TRUE + THEN totalAmount + ELSE NULL + END + ) AS new_amount + FROM + Orders + `; + + const res = await connection.query(query); + expect(res.rows).toMatchSnapshot('measure-with-ad-hoc-filters'); + }); + + test('measure with ad-hoc filter and original measure', async () => { + const query = ` + SELECT + SUM(totalAmount) AS total_amount, + SUM( + CASE status = 'new' + WHEN TRUE + THEN totalAmount + ELSE NULL + END + ) AS new_amount + FROM + Orders + `; + + const res = await connection.query(query); + expect(res.rows).toMatchSnapshot('measure-with-ad-hoc-filters-and-original-measure'); + }); }); }); diff --git a/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_patch_measure.snap b/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_patch_measure.snap new file mode 100644 index 0000000000000..cd8b42c093160 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_patch_measure.snap @@ -0,0 +1,22 @@ +--- +source: cubesql/src/compile/engine/df/wrapper.rs +expression: "UngroupedMemberDef\n{\n cube_name: \"cube\".to_string(), alias: \"alias\".to_string(), expr:\n UngroupedMemberExpr::PatchMeasure(PatchMeasureDef\n {\n source_measure: \"cube.measure\".to_string(), replace_aggregation_type:\n None, add_filters:\n vec![SqlFunctionExpr\n {\n cube_params: vec![\"cube\".to_string()], sql:\n \"1 + 2 = 3\".to_string(),\n }],\n }), grouping_set: None,\n}" +--- +{ + "cubeName": "cube", + "alias": "alias", + "expr": { + "type": "PatchMeasure", + "sourceMeasure": "cube.measure", + "replaceAggregationType": null, + "addFilters": [ + { + "cubeParams": [ + "cube" + ], + "sql": "1 + 2 = 3" + } + ] + }, + "groupingSet": null +} diff --git a/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_sql.snap b/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_sql.snap new file mode 100644 index 0000000000000..3d939de4a6947 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/engine/df/snapshots/cubesql__compile__engine__df__wrapper__tests__member_expression_sql.snap @@ -0,0 +1,17 @@ +--- +source: cubesql/src/compile/engine/df/wrapper.rs +expression: "UngroupedMemberDef\n{\n cube_name: \"cube\".to_string(), alias: \"alias\".to_string(), expr:\n UngroupedMemberExpr::SqlFunction(SqlFunctionExpr\n {\n cube_params: vec![\"cube\".to_string(), \"other\".to_string()], sql:\n \"1 + 2\".to_string(),\n }), grouping_set: None,\n}" +--- +{ + "cubeName": "cube", + "alias": "alias", + "expr": { + "type": "SqlFunction", + "cubeParams": [ + "cube", + "other" + ], + "sql": "1 + 2" + }, + "groupingSet": null +} diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index c8af1cad5ef84..9b27bd928a472 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -1,6 +1,9 @@ use crate::{ compile::{ - engine::df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode}, + engine::{ + df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode}, + udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME}, + }, rewrite::{ extract_exprlist_from_groupping_set, rules::{ @@ -66,11 +69,36 @@ pub struct SqlQuery { } #[derive(Debug, Clone, Serialize)] -struct UngrouppedMemberDef { +struct SqlFunctionExpr { + #[serde(rename = "cubeParams")] + cube_params: Vec, + sql: String, +} + +#[derive(Debug, Clone, Serialize)] +struct PatchMeasureDef { + #[serde(rename = "sourceMeasure")] + source_measure: String, + #[serde(rename = "replaceAggregationType")] + replace_aggregation_type: Option, + #[serde(rename = "addFilters")] + add_filters: Vec, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type")] +enum UngroupedMemberExpr { + SqlFunction(SqlFunctionExpr), + PatchMeasure(PatchMeasureDef), +} + +#[derive(Debug, Clone, Serialize)] +struct UngroupedMemberDef { + #[serde(rename = "cubeName")] cube_name: String, alias: String, - cube_params: Vec, - expr: String, + expr: UngroupedMemberExpr, + #[serde(rename = "groupingSet")] grouping_set: Option, } @@ -82,8 +110,10 @@ pub enum GroupingSetType { #[derive(Clone, Serialize, Debug, PartialEq, Eq)] pub struct GroupingSetDesc { + #[serde(rename = "groupType")] pub group_type: GroupingSetType, pub id: u64, + #[serde(rename = "subId")] pub sub_id: Option, } @@ -843,7 +873,6 @@ impl CubeScanWrapperNode { // Here it should just generate the literal // 2. It would not allow to provide aliases for expressions, instead it usually generates them let (expr, sql) = Self::generate_sql_for_expr( - plan.clone(), new_sql, generator.clone(), expr, @@ -1154,7 +1183,6 @@ impl CubeScanWrapperNode { })? .clone(); let (projection, sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), projection_expr.clone(), sql, @@ -1168,7 +1196,6 @@ impl CubeScanWrapperNode { .await?; let flat_group_expr = extract_exprlist_from_groupping_set(&group_expr); let (group_by, sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), flat_group_expr.clone(), sql, @@ -1181,8 +1208,21 @@ impl CubeScanWrapperNode { ) .await?; let group_descs = extract_group_type_from_groupping_set(&group_expr)?; + + let (patch_measures, aggr_expr, sql) = Self::extract_patch_measures( + schema.as_ref(), + aggr_expr, + sql, + generator.clone(), + column_remapping, + &mut next_remapper, + can_rename_columns, + push_to_cube_context, + subqueries_sql.clone(), + ) + .await?; + let (aggregate, sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), aggr_expr.clone(), sql, @@ -1196,7 +1236,6 @@ impl CubeScanWrapperNode { .await?; let (filter, sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), filter_expr.clone(), sql, @@ -1210,7 +1249,6 @@ impl CubeScanWrapperNode { .await?; let (window, sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), window_expr.clone(), sql, @@ -1224,7 +1262,6 @@ impl CubeScanWrapperNode { .await?; let (order, mut sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), order_expr.clone(), sql, @@ -1252,7 +1289,6 @@ impl CubeScanWrapperNode { { // Need to call generate_column_expr to apply column_remapping let (join_condition, new_sql) = Self::generate_column_expr( - plan.clone(), schema.clone(), [condition.clone()], sql, @@ -1345,6 +1381,11 @@ impl CubeScanWrapperNode { &ungrouped_scan_node.used_cubes, ) })) + .chain(patch_measures.into_iter().map( + |(def, cube, alias)| { + Self::patch_measure_expr(def, cube, alias) + }, + )) .collect::>()?, ), dimensions: Some( @@ -1393,38 +1434,28 @@ impl CubeScanWrapperNode { .. } => { let col_name = expr_name(&expr, &schema)?; - let aliased_column = aggr_expr - .iter() - .find_position(|e| { - expr_name(e, &schema).map(|n| n == col_name).unwrap_or(false) - }) - .map(|(i, _)| aggregate[i].clone()).or_else(|| { - projection_expr - .iter() - .find_position(|e| { - expr_name(e, &schema).map(|n| n == col_name).unwrap_or(false) - }) - .map(|(i, _)| { - projection[i].clone() - }) - }).or_else(|| { - flat_group_expr - .iter() - .find_position(|e| { - expr_name(e, &schema).map(|n| n == col_name).unwrap_or(false) - }) - .map(|(i, _)| group_by[i].clone()) - }).ok_or_else(|| { - DataFusionError::Execution(format!( - "Can't find column {} in projection {:?} or aggregate {:?} or group {:?}", - col_name, - projection_expr, - aggr_expr, - flat_group_expr - )) - })?; + + let find_column = |exprs: &[Expr], columns: &[(AliasedColumn, HashSet)]| -> Option { + exprs.iter().zip(columns.iter()) + .find(|(e, _c)| expr_name(e, &schema).map(|n| n == col_name).unwrap_or(false)) + .map(|(_e, c)| c.0.clone()) + }; + + // TODO handle patch measures collection here + let aliased_column = find_column(&aggr_expr, &aggregate) + .or_else(|| find_column(&projection_expr, &projection)) + .or_else(|| find_column(&flat_group_expr, &group_by)) + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Can't find column {} in projection {:?} or aggregate {:?} or group {:?}", + col_name, + projection_expr, + aggr_expr, + flat_group_expr + )) + })?; Ok(vec![ - aliased_column.0.alias.clone(), + aliased_column.alias, if *asc { "asc".to_string() } else { "desc".to_string() }, ]) } @@ -1491,6 +1522,12 @@ impl CubeScanWrapperNode { request: load_request.clone(), }) } else { + if !patch_measures.is_empty() { + return Err(CubeError::internal(format!( + "Unexpected patch measures for non-push-to-Cube wrapped select: {patch_measures:?}", + ))); + } + let resulting_sql = generator .get_sql_templates() .select( @@ -1563,8 +1600,204 @@ impl CubeScanWrapperNode { }) } + fn get_patch_measure<'l>( + sql_query: SqlQuery, + sql_generator: Arc, + expr: &'l Expr, + push_to_cube_context: Option<&'l PushToCubeContext<'_>>, + subqueries: Arc>, + ) -> Pin< + Box< + dyn Future< + Output = result::Result< + (Option<(PatchMeasureDef, String)>, SqlQuery), + CubeError, + >, + > + Send + + 'l, + >, + > { + Box::pin(async move { + match expr { + Expr::Alias(inner, _alias) => { + Self::get_patch_measure( + sql_query, + sql_generator, + inner, + push_to_cube_context, + subqueries, + ) + .await + } + Expr::AggregateUDF { fun, args } => { + if fun.name != PATCH_MEASURE_UDAF_NAME { + return Ok((None, sql_query)); + } + + let Some(PushToCubeContext { + ungrouped_scan_node, + .. + }) = push_to_cube_context + else { + return Err(CubeError::internal(format!( + "Unexpected UDAF expression without push-to-Cube context: {}", + fun.name + ))); + }; + + let (measure, aggregation, filter) = match args.as_slice() { + [measure, aggregation, filter] => (measure, aggregation, filter), + _ => { + return Err(CubeError::internal(format!( + "Unexpected number arguments for UDAF: {}, {args:?}", + fun.name + ))) + } + }; + + let Expr::Column(measure_column) = measure else { + return Err(CubeError::internal(format!( + "First argument should be column expression: {}", + fun.name + ))); + }; + + let aggregation = match aggregation { + Expr::Literal(ScalarValue::Utf8(Some(aggregation))) => Some(aggregation), + Expr::Literal(ScalarValue::Null) => None, + _ => { + return Err(CubeError::internal(format!( + "Second argument should be Utf8 literal expression: {}", + fun.name + ))); + } + }; + + let (filters, sql_query) = match filter { + Expr::Literal(ScalarValue::Null) => (vec![], sql_query), + _ => { + let mut used_members = HashSet::new(); + let (filter, sql_query) = Self::generate_sql_for_expr( + sql_query, + sql_generator.clone(), + filter.clone(), + push_to_cube_context, + subqueries.clone(), + Some(&mut used_members), + ) + .await?; + + let used_cubes = Self::prepare_used_cubes(&used_members); + + ( + vec![SqlFunctionExpr { + cube_params: used_cubes, + sql: filter, + }], + sql_query, + ) + } + }; + + let member = + Self::find_member_in_ungrouped_scan(ungrouped_scan_node, measure_column)?; + + let MemberField::Member(member) = member else { + return Err(CubeError::internal(format!( + "First argument should reference member, not literal: {}", + fun.name + ))); + }; + + let (cube, _member) = member.split_once('.').ok_or_else(|| { + CubeError::internal(format!("Can't parse cube name from member {member}",)) + })?; + + Ok(( + Some(( + PatchMeasureDef { + source_measure: member.clone(), + replace_aggregation_type: aggregation.cloned(), + add_filters: filters, + }, + cube.to_string(), + )), + sql_query, + )) + } + _ => Ok((None, sql_query)), + } + }) + } + + async fn extract_patch_measures( + schema: &DFSchema, + exprs: impl IntoIterator, + mut sql_query: SqlQuery, + sql_generator: Arc, + column_remapping: Option<&ColumnRemapping>, + next_remapper: &mut Remapper, + can_rename_columns: bool, + push_to_cube_context: Option<&PushToCubeContext<'_>>, + subqueries: Arc>, + ) -> result::Result<(Vec<(PatchMeasureDef, String, String)>, Vec, SqlQuery), CubeError> + { + let mut patches = vec![]; + let mut other = vec![]; + + for original_expr in exprs { + let (patch_def, sql_query_next) = Self::get_patch_measure( + sql_query, + sql_generator.clone(), + &original_expr, + push_to_cube_context, + subqueries.clone(), + ) + .await?; + sql_query = sql_query_next; + if let Some((patch_def, cube)) = patch_def { + let (_expr, alias) = Self::remap_column_expression( + schema, + &original_expr, + column_remapping, + next_remapper, + can_rename_columns, + )?; + + patches.push((patch_def, cube, alias)); + } else { + other.push(original_expr); + } + } + + Ok((patches, other, sql_query)) + } + + fn remap_column_expression( + schema: &DFSchema, + original_expr: &Expr, + column_remapping: Option<&ColumnRemapping>, + next_remapper: &mut Remapper, + can_rename_columns: bool, + ) -> result::Result<(Expr, String), CubeError> { + let expr = if let Some(column_remapping) = column_remapping { + let mut expr = column_remapping.remap(original_expr)?; + if !can_rename_columns { + let original_alias = expr_name(original_expr, &schema)?; + if original_alias != expr_name(&expr, &schema)? { + expr = Expr::Alias(Box::new(expr), original_alias.clone()); + } + } + expr + } else { + original_expr.clone() + }; + let alias = next_remapper.add_expr(&schema, original_expr, &expr)?; + + Ok((expr, alias)) + } + async fn generate_column_expr( - plan: Arc, schema: DFSchemaRef, exprs: impl IntoIterator, mut sql: SqlQuery, @@ -1577,22 +1810,16 @@ impl CubeScanWrapperNode { ) -> result::Result<(Vec<(AliasedColumn, HashSet)>, SqlQuery), CubeError> { let mut aliased_columns = Vec::new(); for original_expr in exprs { - let expr = if let Some(column_remapping) = column_remapping { - let mut expr = column_remapping.remap(&original_expr)?; - if !can_rename_columns { - let original_alias = expr_name(&original_expr, &schema)?; - if original_alias != expr_name(&expr, &schema)? { - expr = Expr::Alias(Box::new(expr), original_alias.clone()); - } - } - expr - } else { - original_expr.clone() - }; + let (expr, alias) = Self::remap_column_expression( + schema.as_ref(), + &original_expr, + column_remapping, + next_remapper, + can_rename_columns, + )?; let mut used_members = HashSet::new(); let (expr_sql, new_sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql, generator.clone(), expr.clone(), @@ -1605,7 +1832,6 @@ impl CubeScanWrapperNode { Self::escape_interpolation_quotes(expr_sql, push_to_cube_context.is_some()); sql = new_sql_query; - let alias = next_remapper.add_expr(&schema, &original_expr, &expr)?; aliased_columns.push(( AliasedColumn { expr: expr_sql, @@ -1617,18 +1843,22 @@ impl CubeScanWrapperNode { Ok((aliased_columns, sql)) } - fn make_member_def<'m>( - column: &AliasedColumn, - used_members: impl IntoIterator, - ungrouped_scan_cubes: &Vec, - ) -> Result { - let used_cubes = used_members + fn prepare_used_cubes<'m>(used_members: impl IntoIterator) -> Vec { + used_members .into_iter() .flat_map(|member| member.split_once('.')) .map(|(cube, _rest)| cube) .unique() .map(|cube| cube.to_string()) - .collect::>(); + .collect::>() + } + + fn make_member_def<'m>( + column: &AliasedColumn, + used_members: impl IntoIterator, + ungrouped_scan_cubes: &Vec, + ) -> Result { + let used_cubes = Self::prepare_used_cubes(used_members); let cube_name = used_cubes .first() .or_else(|| ungrouped_scan_cubes.first()) @@ -1640,11 +1870,13 @@ impl CubeScanWrapperNode { })? .clone(); - let res = UngrouppedMemberDef { + let res = UngroupedMemberDef { cube_name, alias: column.alias.clone(), - cube_params: used_cubes, - expr: column.expr.clone(), + expr: UngroupedMemberExpr::SqlFunction(SqlFunctionExpr { + cube_params: used_cubes, + sql: column.expr.clone(), + }), grouping_set: None, }; Ok(res) @@ -1670,6 +1902,21 @@ impl CubeScanWrapperNode { Ok(serde_json::json!(res).to_string()) } + fn patch_measure_expr( + def: PatchMeasureDef, + cube_name: String, + alias: String, + ) -> Result { + let res = UngroupedMemberDef { + cube_name, + alias, + expr: UngroupedMemberExpr::PatchMeasure(def), + grouping_set: None, + }; + + Ok(serde_json::json!(res).to_string()) + } + fn generate_sql_cast_expr( sql_generator: Arc, inner_expr: String, @@ -1715,7 +1962,6 @@ impl CubeScanWrapperNode { /// This function is async to be able to call to JS land, /// in case some SQL generation could not be done through Jinja pub fn generate_sql_for_expr<'ctx>( - plan: Arc, mut sql_query: SqlQuery, sql_generator: Arc, expr: Expr, @@ -1727,7 +1973,6 @@ impl CubeScanWrapperNode { match expr { Expr::Alias(expr, _) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -1766,7 +2011,6 @@ impl CubeScanWrapperNode { // So we can generate that as if it were regular column expression return Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), expr, @@ -1778,33 +2022,8 @@ impl CubeScanWrapperNode { } } - let field_index = ungrouped_scan_node - .schema - .fields() - .iter() - .find_position(|f| { - f.name() == &c.name - && match c.relation.as_ref() { - Some(r) => Some(r) == f.qualifier(), - None => true, - } - }) - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Can't find column {} in ungrouped scan node", - c - )) - })? - .0; - let member = ungrouped_scan_node - .member_fields - .get(field_index) - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Can't find member for column {} in ungrouped scan node", - c - )) - })?; + let member = Self::find_member_in_ungrouped_scan(ungrouped_scan_node, c)?; + match member { MemberField::Member(member) => { if let Some(used_members) = used_members { @@ -1814,7 +2033,6 @@ impl CubeScanWrapperNode { } MemberField::Literal(value) => { Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), Expr::Literal(value.clone()), @@ -1866,7 +2084,6 @@ impl CubeScanWrapperNode { // Expr::ScalarVariable(_, _) => {} Expr::BinaryExpr { left, op, right } => { let (left, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *left, @@ -1876,7 +2093,6 @@ impl CubeScanWrapperNode { ) .await?; let (right, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *right, @@ -1899,7 +2115,6 @@ impl CubeScanWrapperNode { // Expr::AnyExpr { .. } => {} Expr::Like(like) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *like.expr, @@ -1909,7 +2124,6 @@ impl CubeScanWrapperNode { ) .await?; let (pattern, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *like.pattern, @@ -1921,7 +2135,6 @@ impl CubeScanWrapperNode { let (escape_char, sql_query) = match like.escape_char { Some(escape_char) => { let (escape_char, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), Expr::Literal(ScalarValue::Utf8(Some(escape_char.to_string()))), @@ -1947,7 +2160,6 @@ impl CubeScanWrapperNode { } Expr::ILike(ilike) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *ilike.expr, @@ -1957,7 +2169,6 @@ impl CubeScanWrapperNode { ) .await?; let (pattern, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *ilike.pattern, @@ -1969,7 +2180,6 @@ impl CubeScanWrapperNode { let (escape_char, sql_query) = match ilike.escape_char { Some(escape_char) => { let (escape_char, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), Expr::Literal(ScalarValue::Utf8(Some(escape_char.to_string()))), @@ -1996,7 +2206,6 @@ impl CubeScanWrapperNode { // Expr::SimilarTo(_) => {} Expr::Not(expr) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2019,7 +2228,6 @@ impl CubeScanWrapperNode { } Expr::IsNotNull(expr) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2041,7 +2249,6 @@ impl CubeScanWrapperNode { } Expr::IsNull(expr) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2063,7 +2270,6 @@ impl CubeScanWrapperNode { } Expr::Negative(expr) => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2092,7 +2298,6 @@ impl CubeScanWrapperNode { } => { let expr = if let Some(expr) = expr { let (expr, sql_query_next) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2109,7 +2314,6 @@ impl CubeScanWrapperNode { let mut when_then_expr_sql = Vec::new(); for (when, then) in when_then_expr { let (when, sql_query_next) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *when, @@ -2119,7 +2323,6 @@ impl CubeScanWrapperNode { ) .await?; let (then, sql_query_next) = Self::generate_sql_for_expr( - plan.clone(), sql_query_next, sql_generator.clone(), *then, @@ -2133,7 +2336,6 @@ impl CubeScanWrapperNode { } let else_expr = if let Some(else_expr) = else_expr { let (else_expr, sql_query_next) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *else_expr, @@ -2157,7 +2359,6 @@ impl CubeScanWrapperNode { } Expr::Cast { expr, data_type } => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2178,7 +2379,6 @@ impl CubeScanWrapperNode { nulls_first, } => { let (expr, sql_query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2586,7 +2786,6 @@ impl CubeScanWrapperNode { let mut sql_args = Vec::new(); for arg in args { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2624,7 +2823,6 @@ impl CubeScanWrapperNode { ))); } let (arg_sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), args[1].clone(), @@ -2668,7 +2866,6 @@ impl CubeScanWrapperNode { let mut sql_args = Vec::new(); for arg in args { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2709,7 +2906,6 @@ impl CubeScanWrapperNode { } } let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2739,7 +2935,6 @@ impl CubeScanWrapperNode { let mut sql_exprs = Vec::new(); for expr in exprs { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), expr, @@ -2768,7 +2963,6 @@ impl CubeScanWrapperNode { let mut sql_exprs = Vec::new(); for expr in exprs { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), expr, @@ -2810,7 +3004,6 @@ impl CubeScanWrapperNode { let mut sql_args = Vec::new(); for arg in args { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2825,7 +3018,6 @@ impl CubeScanWrapperNode { let mut sql_partition_by = Vec::new(); for arg in partition_by { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2840,7 +3032,6 @@ impl CubeScanWrapperNode { let mut sql_order_by = Vec::new(); for arg in order_by { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), arg, @@ -2869,7 +3060,56 @@ impl CubeScanWrapperNode { })?; Ok((resulting_sql, sql_query)) } - // Expr::AggregateUDF { .. } => {} + Expr::AggregateUDF { ref fun, ref args } => { + match fun.name.as_str() { + // TODO allow this only in agg expr + MEASURE_UDAF_NAME => { + let Some(PushToCubeContext { + ungrouped_scan_node, + .. + }) = push_to_cube_context + else { + return Err(DataFusionError::Internal(format!( + "Unexpected {} UDAF expression without push-to-Cube context: {expr}", + fun.name, + ))); + }; + + let measure_column = match args.as_slice() { + [Expr::Column(measure_column)] => measure_column, + _ => { + return Err(DataFusionError::Internal(format!( + "Unexpected arguments for {} UDAF: {expr}", + fun.name, + ))) + } + }; + + let member = Self::find_member_in_ungrouped_scan( + ungrouped_scan_node, + measure_column, + )?; + + let MemberField::Member(member) = member else { + return Err(DataFusionError::Internal(format!( + "First argument for {} UDAF should reference member, not literal: {expr}", + fun.name, + ))); + }; + + if let Some(used_members) = used_members { + used_members.insert(member.clone()); + } + + Ok((format!("${{{member}}}"), sql_query)) + } + // There's no branch for PatchMeasure, because it should generate via different path + _ => Err(DataFusionError::Internal(format!( + "Can't generate SQL for UDAF: {}", + fun.name + ))), + } + } Expr::InList { expr, list, @@ -2877,7 +3117,6 @@ impl CubeScanWrapperNode { } => { let mut sql_query = sql_query; let (sql_expr, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2890,7 +3129,6 @@ impl CubeScanWrapperNode { let mut sql_in_exprs = Vec::new(); for expr in list { let (sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), expr, @@ -2922,7 +3160,6 @@ impl CubeScanWrapperNode { } => { let mut sql_query = sql_query; let (sql_expr, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *expr, @@ -2933,7 +3170,6 @@ impl CubeScanWrapperNode { .await?; sql_query = query; let (subquery_sql, query) = Self::generate_sql_for_expr( - plan.clone(), sql_query, sql_generator.clone(), *subquery, @@ -2969,6 +3205,31 @@ impl CubeScanWrapperNode { }) } + fn find_member_in_ungrouped_scan<'scan, 'col>( + ungrouped_scan_node: &'scan CubeScanNode, + column: &'col Column, + ) -> Result<&'scan MemberField> { + let (_field, member) = ungrouped_scan_node + .schema + .fields() + .iter() + .zip(ungrouped_scan_node.member_fields.iter()) + .find(|(f, _mf)| { + f.name() == &column.name + && match column.relation.as_ref() { + Some(r) => Some(r) == f.qualifier(), + None => true, + } + }) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Can't find member for column {column} in ungrouped scan node" + )) + })?; + + Ok(member) + } + fn escape_interpolation_quotes(s: String, ungrouped: bool) -> String { if ungrouped { s.replace("\\", "\\\\").replace("`", "\\`") @@ -3017,3 +3278,38 @@ impl UserDefinedLogicalNode for CubeScanWrapperNode { }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_member_expression_sql() { + insta::assert_json_snapshot!(UngroupedMemberDef { + cube_name: "cube".to_string(), + alias: "alias".to_string(), + expr: UngroupedMemberExpr::SqlFunction(SqlFunctionExpr { + cube_params: vec!["cube".to_string(), "other".to_string()], + sql: "1 + 2".to_string(), + }), + grouping_set: None, + }); + } + + #[test] + fn test_member_expression_patch_measure() { + insta::assert_json_snapshot!(UngroupedMemberDef { + cube_name: "cube".to_string(), + alias: "alias".to_string(), + expr: UngroupedMemberExpr::PatchMeasure(PatchMeasureDef { + source_measure: "cube.measure".to_string(), + replace_aggregation_type: None, + add_filters: vec![SqlFunctionExpr { + cube_params: vec!["cube".to_string()], + sql: "1 + 2 = 3".to_string(), + }], + }), + grouping_set: None, + }); + } +} diff --git a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs index 853252f04a266..096ea843f7f0e 100644 --- a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs +++ b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs @@ -24,7 +24,7 @@ use datafusion::{ }, error::{DataFusionError, Result}, execution::context::SessionContext, - logical_plan::{create_udaf, create_udf}, + logical_plan::create_udf, physical_plan::{ functions::{ datetime_expressions::date_trunc, make_scalar_function, make_table_function, Signature, @@ -2259,14 +2259,66 @@ pub fn create_pg_get_constraintdef_udf() -> ScalarUDF { ) } +pub const MEASURE_UDAF_NAME: &str = "measure"; + pub fn create_measure_udaf() -> AggregateUDF { - create_udaf( - "measure", - DataType::Float64, - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|| todo!("Not implemented")), - Arc::new(vec![DataType::Float64]), + let signature = Signature::any(1, Volatility::Immutable); + + // MEASURE(cube.measure) should have same type as just cube.measure + let return_type: ReturnTypeFunction = Arc::new(move |inputs| { + if inputs.len() != 1 { + Err(DataFusionError::Internal(format!( + "Unexpected argument types for MEASURE: {inputs:?}" + ))) + } else { + Ok(Arc::new(inputs[0].clone())) + } + }); + + let accumulator: AccumulatorFunctionImplementation = Arc::new(|| todo!("Not implemented")); + + let state_type = Arc::new(vec![DataType::Float64]); + let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + + AggregateUDF::new( + MEASURE_UDAF_NAME, + &signature, + &return_type, + &accumulator, + &state_type, + ) +} + +pub const PATCH_MEASURE_UDAF_NAME: &str = "__patch_measure"; + +// TODO add sanity check on incoming query to disallow it in input +pub fn create_patch_measure_udaf() -> AggregateUDF { + // TODO actually signature should look like (any, text, boolean) + let signature = Signature::any(3, Volatility::Immutable); + + // __PATCH_MEASURE(cube.measure, type, filter) should have same type as just cube.measure + let return_type: ReturnTypeFunction = Arc::new(move |inputs| { + if inputs.len() != 3 { + Err(DataFusionError::Internal(format!( + "Unexpected argument types for {PATCH_MEASURE_UDAF_NAME}: {inputs:?}" + ))) + } else { + Ok(Arc::new(inputs[0].clone())) + } + }); + + let accumulator: AccumulatorFunctionImplementation = + Arc::new(|| todo!("Internal, should not execute")); + + let state_type = Arc::new(vec![DataType::Float64]); + let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + + AggregateUDF::new( + PATCH_MEASURE_UDAF_NAME, + &signature, + &return_type, + &accumulator, + &state_type, ) } diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index b3840407504d3..d145131eeab39 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -2783,17 +2783,14 @@ limit #[tokio::test] async fn test_select_error() { - let variants = [ - ( - "SELECT AVG(maxPrice) FROM KibanaSampleDataEcommerce".to_string(), - CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'maxPrice' is 'MAX()' but 'AVG()' was provided. Please check logs for additional information.".to_string()), - ), + let variants: &[(&str, _)] = &[ + // TODO are there any errors that we could test for? ]; - for (input_query, expected_error) in variants.iter() { + for (input_query, expected_error) in variants { let meta = get_test_tenant_ctx(); let query = convert_sql_to_cube_query( - &input_query, + input_query, meta.clone(), get_test_session(DatabaseProtocol::PostgreSQL, meta).await, ) @@ -7229,11 +7226,14 @@ ORDER BY fn trivial_member_expr(cube: &str, member: &str, alias: &str) -> String { json!({ - "cube_name": cube, + "cubeName": cube, "alias": alias, - "cube_params": [cube], - "expr": format!("${{{cube}.{member}}}"), - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": [cube], + "sql": format!("${{{cube}.{member}}}"), + }, + "groupingSet": null, }) .to_string() } @@ -7246,35 +7246,47 @@ ORDER BY V1LoadRequestQuery { measures: Some(vec![ json!({ - "cube_name": "WideCube", + "cubeName": "WideCube", "alias": "max_source_measu", - "cube_params": ["WideCube"], - "expr": "${WideCube.measure1}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["WideCube"], + "sql": "${WideCube.measure1}", + }, + "groupingSet": null, }) .to_string(), json!({ - "cube_name": "WideCube", + "cubeName": "WideCube", "alias": "max_source_measu_1", - "cube_params": ["WideCube"], - "expr": "${WideCube.measure2}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["WideCube"], + "sql": "${WideCube.measure2}", + }, + "groupingSet": null, }) .to_string(), json!({ - "cube_name": "WideCube", + "cubeName": "WideCube", "alias": "sum_source_measu", - "cube_params": ["WideCube"], - "expr": "${WideCube.measure3}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["WideCube"], + "sql": "${WideCube.measure3}", + }, + "groupingSet": null, }) .to_string(), json!({ - "cube_name": "WideCube", + "cubeName": "WideCube", "alias": "max_source_measu_2", - "cube_params": ["WideCube"], - "expr": "${WideCube.measure4}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["WideCube"], + "sql": "${WideCube.measure4}", + }, + "groupingSet": null, }) .to_string(), ]), @@ -7283,11 +7295,14 @@ ORDER BY trivial_member_expr("WideCube", "dim3", "dim3"), trivial_member_expr("WideCube", "dim4", "dim4"), json!({ - "cube_name": "WideCube", + "cubeName": "WideCube", "alias": "pivot_grouping", - "cube_params": [], - "expr": "0", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": [], + "sql": "0", + }, + "groupingSet": null, }) .to_string() ]), @@ -11807,20 +11822,26 @@ ORDER BY "source"."str0" ASC measures: Some(vec![]), dimensions: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "ta_1_order_date_", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "((${KibanaSampleDataEcommerce.order_date} = DATE('1994-05-01')) OR (${KibanaSampleDataEcommerce.order_date} = DATE('1996-05-03')))", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "((${KibanaSampleDataEcommerce.order_date} = DATE('1994-05-01')) OR (${KibanaSampleDataEcommerce.order_date} = DATE('1996-05-03')))", + }, + "groupingSet": null, }).to_string(), ]), segments: Some(vec![ json!({ - "cube_name": "Logs", + "cubeName": "Logs", "alias": "lower_ta_2_conte", - "cube_params": ["Logs"], - "expr": "(LOWER(${Logs.content}) = $0$)", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["Logs"], + "sql": "(LOWER(${Logs.content}) = $0$)", + }, + "groupingSet": null, }).to_string(), ]), order: Some(vec![]), @@ -12088,34 +12109,46 @@ ORDER BY "source"."str0" ASC measures: Some(vec![]), dimensions: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "customer_gender", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "${KibanaSampleDataEcommerce.customer_gender}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "${KibanaSampleDataEcommerce.customer_gender}", + }, + "groupingSet": null, }).to_string(), json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "cast_dateadd_utf", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "CAST(DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2 DAY') AS DATE)", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "CAST(DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2 DAY') AS DATE)", + }, + "groupingSet": null, }).to_string(), json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "dateadd_utf8__se", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2000000 MILLISECOND')", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2000000 MILLISECOND')", + }, + "groupingSet": null, }).to_string(), ]), segments: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "dateadd_utf8__da", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "(DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2 DAY') < DATE('2014-06-02'))", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "(DATE_ADD(${KibanaSampleDataEcommerce.order_date}, INTERVAL '2 DAY') < DATE('2014-06-02'))", + }, + "groupingSet": null, }).to_string(), ]), order: Some(vec![]), @@ -12914,29 +12947,38 @@ ORDER BY "source"."str0" ASC V1LoadRequestQuery { measures: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "avg_kibanasample", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "${KibanaSampleDataEcommerce.avgPrice}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "${KibanaSampleDataEcommerce.avgPrice}", + }, + "groupingSet": null, }).to_string(), ]), dimensions: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "cast_kibanasampl", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "CAST(${KibanaSampleDataEcommerce.order_date} AS DATE)", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "CAST(${KibanaSampleDataEcommerce.order_date} AS DATE)", + }, + "groupingSet": null, }).to_string(), ]), segments: Some(vec![ json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "kibanasampledata", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": format!("(((${{KibanaSampleDataEcommerce.order_date}} >= CAST((NOW() + INTERVAL '-30 DAY') AS DATE)) AND (${{KibanaSampleDataEcommerce.order_date}} < CAST(NOW() AS DATE))) AND (((${{KibanaSampleDataEcommerce.notes}} = $0$) OR (${{KibanaSampleDataEcommerce.notes}} = $1$)) OR (${{KibanaSampleDataEcommerce.notes}} = $2$)))"), - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": format!("(((${{KibanaSampleDataEcommerce.order_date}} >= CAST((NOW() + INTERVAL '-30 DAY') AS DATE)) AND (${{KibanaSampleDataEcommerce.order_date}} < CAST(NOW() AS DATE))) AND (((${{KibanaSampleDataEcommerce.notes}} = $0$) OR (${{KibanaSampleDataEcommerce.notes}} = $1$)) OR (${{KibanaSampleDataEcommerce.notes}} = $2$)))"), + }, + "groupingSet": null, }).to_string(), ]), order: Some(vec![]), diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index 4de9b441e9a58..b6b3773ab7186 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -516,6 +516,7 @@ impl QueryEngine for SqlQueryEngine { // udaf ctx.register_udaf(create_measure_udaf()); + ctx.register_udaf(create_patch_measure_udaf()); // udtf ctx.register_udtf(create_generate_series_udtf()); diff --git a/rust/cubesql/cubesql/src/compile/rewrite/analysis.rs b/rust/cubesql/cubesql/src/compile/rewrite/analysis.rs index a96b03403e1e0..1363fc9b6c256 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/analysis.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/analysis.rs @@ -1,5 +1,6 @@ use crate::{ compile::{ + engine::udf::MEASURE_UDAF_NAME, rewrite::{ converter::{is_expr_node, node_to_expr, LogicalPlanToLanguageConverter}, expr_column_name, @@ -389,7 +390,7 @@ impl LogicalPlanAnalysis { Some(trivial) } LogicalPlanLanguage::AggregateUDFExprFun(AggregateUDFExprFun(fun)) => { - if fun.to_lowercase() == "measure" { + if fun.to_lowercase() == MEASURE_UDAF_NAME { Some(0) } else { None diff --git a/rust/cubesql/cubesql/src/compile/rewrite/language.rs b/rust/cubesql/cubesql/src/compile/rewrite/language.rs index 5622a02377c40..81ccd7470db7d 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/language.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/language.rs @@ -261,8 +261,12 @@ macro_rules! variant_field_struct { impl FromStr for [<$variant $var_field:camel>] { type Err = $crate::compile::rewrite::language::LanguageParseError; - fn from_str(_s: &str) -> Result { - Err(Self::Err::NotSupported) + fn from_str(s: &str) -> Result { + const PREFIX: &'static str = concat!(std::stringify!([<$variant $var_field:camel>]), ":"); + if let Some(suffix) = s.strip_prefix(PREFIX) { + return Ok([<$variant $var_field:camel>](suffix.to_string())); + } + Err(Self::Err::ShouldStartWith(PREFIX)) } } @@ -640,6 +644,8 @@ macro_rules! variant_field_struct { } else if let Some(value) = typed_str.strip_prefix("f:") { let n: f64 = value.parse().map_err(|err| Self::Err::InvalidFloatValue(err))?; Ok([<$variant $var_field:camel>](ScalarValue::Float64(Some(n)))) + } else if typed_str == "null" { + Ok([<$variant $var_field:camel>](ScalarValue::Null)) } else { Err(Self::Err::InvalidScalarType) } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs index 1ff66f33b939b..38b05d7dfc34c 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs @@ -1419,9 +1419,13 @@ fn window_fun_expr_var_arg( } fn udaf_expr(fun_name: impl Display, args: Vec) -> String { + let prefix = if fun_name.to_string().starts_with("?") { + "" + } else { + "AggregateUDFExprFun:" + }; format!( - "(AggregateUDFExpr {} {})", - fun_name, + "(AggregateUDFExpr {prefix}{fun_name} {})", list_expr("AggregateUDFExprArgs", args), ) } @@ -1776,6 +1780,10 @@ fn literal_bool(literal_bool: bool) -> String { format!("(LiteralExpr LiteralExprValue:b:{})", literal_bool) } +fn literal_null() -> String { + format!("(LiteralExpr LiteralExprValue:null)") +} + fn projection( expr: impl Display, input: impl Display, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs index c128ae11dc8aa..1196334ec329a 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs @@ -1,31 +1,35 @@ use crate::{ - compile::rewrite::{ - agg_fun_expr, aggregate, alias_expr, all_members, - analysis::{ConstantFolding, LogicalPlanData, Member, MemberNamesToExpr, OriginalExpr}, - binary_expr, cast_expr, change_user_expr, column_expr, cross_join, cube_scan, - cube_scan_filters_empty_tail, cube_scan_members, cube_scan_members_empty_tail, - cube_scan_order_empty_tail, dimension_expr, distinct, expr_column_name, fun_expr, join, - like_expr, limit, list_concat_pushdown_replacer, list_concat_pushup_replacer, literal_expr, - literal_member, measure_expr, member_pushdown_replacer, member_replacer, - merged_members_replacer, original_expr_name, projection, referenced_columns, rewrite, - rewriter::{CubeEGraph, CubeRewrite, RewriteRules}, - rules::{ - replacer_flat_push_down_node_substitute_rules, replacer_push_down_node, - replacer_push_down_node_substitute_rules, utils, + compile::{ + engine::udf::MEASURE_UDAF_NAME, + rewrite::{ + agg_fun_expr, aggregate, alias_expr, all_members, + analysis::{ConstantFolding, LogicalPlanData, Member, MemberNamesToExpr, OriginalExpr}, + binary_expr, cast_expr, change_user_expr, column_expr, cross_join, cube_scan, + cube_scan_filters_empty_tail, cube_scan_members, cube_scan_members_empty_tail, + cube_scan_order_empty_tail, dimension_expr, distinct, expr_column_name, fun_expr, join, + like_expr, limit, list_concat_pushdown_replacer, list_concat_pushup_replacer, + literal_expr, literal_member, measure_expr, member_pushdown_replacer, member_replacer, + merged_members_replacer, original_expr_name, projection, referenced_columns, rewrite, + rewriter::{CubeEGraph, CubeRewrite, RewriteRules}, + rules::{ + replacer_flat_push_down_node_substitute_rules, replacer_push_down_node, + replacer_push_down_node_substitute_rules, utils, + }, + segment_expr, table_scan, time_dimension_expr, transform_original_expr_to_alias, + transforming_chain_rewrite, transforming_rewrite, transforming_rewrite_with_root, + udaf_expr, udf_expr, virtual_field_expr, AggregateFunctionExprDistinct, + AggregateFunctionExprFun, AliasExprAlias, AllMembersAlias, AllMembersCube, + BinaryExprOp, CastExprDataType, ChangeUserCube, ColumnExprColumn, CubeScanAliasToCube, + CubeScanCanPushdownJoin, CubeScanLimit, CubeScanOffset, CubeScanUngrouped, + DimensionName, JoinLeftOn, JoinRightOn, LikeExprEscapeChar, LikeExprLikeType, + LikeExprNegated, LikeType, LimitFetch, LimitSkip, ListType, LiteralExprValue, + LiteralMemberRelation, LiteralMemberValue, LogicalPlanLanguage, MeasureName, + MemberErrorAliasToCube, MemberErrorError, MemberErrorPriority, + MemberPushdownReplacerAliasToCube, MemberReplacerAliasToCube, ProjectionAlias, + SegmentName, TableScanFetch, TableScanProjection, TableScanSourceTableName, + TableScanTableName, TimeDimensionDateRange, TimeDimensionGranularity, + TimeDimensionName, VirtualFieldCube, VirtualFieldName, }, - segment_expr, table_scan, time_dimension_expr, transform_original_expr_to_alias, - transforming_chain_rewrite, transforming_rewrite, transforming_rewrite_with_root, - udaf_expr, udf_expr, virtual_field_expr, AggregateFunctionExprDistinct, - AggregateFunctionExprFun, AliasExprAlias, AllMembersAlias, AllMembersCube, BinaryExprOp, - CastExprDataType, ChangeUserCube, ColumnExprColumn, CubeScanAliasToCube, - CubeScanCanPushdownJoin, CubeScanLimit, CubeScanOffset, CubeScanUngrouped, DimensionName, - JoinLeftOn, JoinRightOn, LikeExprEscapeChar, LikeExprLikeType, LikeExprNegated, LikeType, - LimitFetch, LimitSkip, ListType, LiteralExprValue, LiteralMemberRelation, - LiteralMemberValue, LogicalPlanLanguage, MeasureName, MemberErrorAliasToCube, - MemberErrorError, MemberErrorPriority, MemberPushdownReplacerAliasToCube, - MemberReplacerAliasToCube, ProjectionAlias, SegmentName, TableScanFetch, - TableScanProjection, TableScanSourceTableName, TableScanTableName, TimeDimensionDateRange, - TimeDimensionGranularity, TimeDimensionName, VirtualFieldCube, VirtualFieldName, }, config::ConfigObj, transport::{MetaContext, V1CubeMetaDimensionExt, V1CubeMetaExt, V1CubeMetaMeasureExt}, @@ -123,7 +127,7 @@ impl RewriteRules for MemberRules { ), self.measure_rewrite( "measure-fun", - udaf_expr("?aggr_fun", vec![column_expr("?column")]), + udaf_expr(MEASURE_UDAF_NAME, vec![column_expr("?column")]), Some("?column"), None, None, @@ -646,7 +650,7 @@ impl MemberRules { )); rules.extend(find_matching_old_member( "udaf-fun", - udaf_expr("?fun_name", vec![column_expr("?column")]), + udaf_expr(MEASURE_UDAF_NAME, vec![column_expr("?column")]), )); rules.extend(find_matching_old_member_with_count( "agg-fun-default-count", @@ -1099,7 +1103,7 @@ impl MemberRules { )); rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-udaf-fun", - udaf_expr("?fun_name", vec![column_expr("?column")]), + udaf_expr(MEASURE_UDAF_NAME, vec![column_expr("?column")]), measure_expr("?name", "?old_alias"), None, None, @@ -1108,7 +1112,7 @@ impl MemberRules { )); rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-udaf-fun-on-dimension", - udaf_expr("?fun_name", vec![column_expr("?column")]), + udaf_expr(MEASURE_UDAF_NAME, vec![column_expr("?column")]), dimension_expr("?name", "?old_alias"), None, None, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs index 63155c173f3db..151ef7257f482 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs @@ -1,25 +1,31 @@ use crate::{ - compile::rewrite::{ - aggregate, - analysis::LogicalPlanData, - cube_scan_wrapper, grouping_set_expr, original_expr_name, rewrite, - rewriter::{CubeEGraph, CubeRewrite}, - rules::{members::MemberRules, wrapper::WrapperRules}, - subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select, - wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail, - wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail, - wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail, - wrapped_select_projection_expr_empty_tail, wrapped_select_subqueries_empty_tail, - wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer, wrapper_pushdown_replacer, - wrapper_replacer_context, AggregateFunctionExprDistinct, AggregateFunctionExprFun, - AliasExprAlias, ColumnExprColumn, ListType, LogicalPlanLanguage, WrappedSelectPushToCube, - WrapperReplacerContextAliasToCube, WrapperReplacerContextPushToCube, + compile::{ + engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME}, + rewrite::{ + agg_fun_expr, aggregate, alias_expr, + analysis::ConstantFolding, + binary_expr, case_expr, column_expr, cube_scan_wrapper, grouping_set_expr, + literal_null, original_expr_name, rewrite, + rewriter::{CubeEGraph, CubeRewrite}, + rules::{members::MemberRules, wrapper::WrapperRules}, + subquery, transforming_chain_rewrite, transforming_rewrite, udaf_expr, wrapped_select, + wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail, + wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail, + wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail, + wrapped_select_projection_expr_empty_tail, wrapped_select_subqueries_empty_tail, + wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer, + wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct, + AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn, + ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage, + WrappedSelectPushToCube, WrapperReplacerContextAliasToCube, + WrapperReplacerContextPushToCube, + }, }, copy_flag, - transport::V1CubeMetaMeasureExt, + transport::{MetaContext, V1CubeMetaMeasureExt}, var, var_iter, }; -use datafusion::logical_plan::Column; +use datafusion::{logical_plan::Column, scalar::ScalarValue}; use egg::{Subst, Var}; use std::ops::IndexMut; @@ -250,7 +256,7 @@ impl WrapperRules { ), vec![("?aggr_expr", aggr_expr)], wrapper_pullup_replacer( - "?measure", + alias_expr("?out_measure_expr", "?out_measure_alias"), wrapper_replacer_context( "?alias_to_cube", "WrapperReplacerContextPushToCube:true", @@ -267,7 +273,8 @@ impl WrapperRules { distinct, cast_data_type, "?cube_members", - "?measure", + "?out_measure_expr", + "?out_measure_alias", ), ) }, @@ -314,6 +321,91 @@ impl WrapperRules { "GroupingSetExprMembers", ); } + + // incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column))) + // optional "else null" is fine + // only single when-then + rules.extend(vec![ + transforming_chain_rewrite( + "wrapper-push-down-aggregation-over-filtered-measure", + wrapper_pushdown_replacer("?aggr_expr", "?context"), + vec![ + ( + "?aggr_expr", + agg_fun_expr( + "?fun", + vec![case_expr( + Some("?case_expr".to_string()), + vec![("?literal".to_string(), column_expr("?measure_column"))], + // TODO make `ELSE NULL` optional and/or add generic rewrite to normalize it + Some(literal_null()), + )], + "?distinct", + ), + ), + ( + "?context", + wrapper_replacer_context( + "?alias_to_cube", + "WrapperReplacerContextPushToCube:true", + "?in_projection", + "?cube_members", + "?grouped_subqueries", + "?ungrouped_scan", + ), + ), + ], + alias_expr( + udaf_expr( + PATCH_MEASURE_UDAF_NAME, + vec![ + column_expr("?measure_column"), + "?replace_agg_type".to_string(), + wrapper_pushdown_replacer( + // = is a proper way to filter here: + // CASE NULL WHEN ... will return null + // So NULL in ?case_expr is equivalent to hitting ELSE branch + // TODO add "is not null" to cond? just to make is always boolean + binary_expr("?case_expr", "=", "?literal"), + "?context", + ), + ], + ), + "?out_measure_alias", + ), + self.transform_filtered_measure( + "?aggr_expr", + "?literal", + "?measure_column", + "?fun", + "?cube_members", + "?replace_agg_type", + "?out_measure_alias", + ), + ), + rewrite( + "wrapper-pull-up-aggregation-over-filtered-measure", + udaf_expr( + PATCH_MEASURE_UDAF_NAME, + vec![ + column_expr("?measure_column"), + "?new_agg_type".to_string(), + wrapper_pullup_replacer("?filter_expr", "?context"), + ], + ), + wrapper_pullup_replacer( + udaf_expr( + PATCH_MEASURE_UDAF_NAME, + vec![ + column_expr("?measure_column"), + "?new_agg_type".to_string(), + "?filter_expr".to_string(), + ], + ), + "?context", + ), + ), + ]); } pub fn aggregate_rules_subquery(&self, rules: &mut Vec) { @@ -818,6 +910,203 @@ impl WrapperRules { } } + fn insert_regular_measure( + egraph: &mut CubeEGraph, + subst: &mut Subst, + column: Column, + alias: String, + out_expr_var: Var, + out_alias_var: Var, + ) { + let column_expr_column = egraph.add(LogicalPlanLanguage::ColumnExprColumn( + ColumnExprColumn(column), + )); + let column_expr = egraph.add(LogicalPlanLanguage::ColumnExpr([column_expr_column])); + let udaf_name_expr = egraph.add(LogicalPlanLanguage::AggregateUDFExprFun( + AggregateUDFExprFun(MEASURE_UDAF_NAME.to_string()), + )); + let udaf_args_expr = + egraph.add(LogicalPlanLanguage::AggregateUDFExprArgs(vec![column_expr])); + let udaf_expr = egraph.add(LogicalPlanLanguage::AggregateUDFExpr([ + udaf_name_expr, + udaf_args_expr, + ])); + + subst.insert(out_expr_var, udaf_expr); + + let alias_expr_alias = + egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(alias))); + subst.insert(out_alias_var, alias_expr_alias); + } + + fn insert_patch_measure( + egraph: &mut CubeEGraph, + subst: &mut Subst, + column: Column, + call_agg_type: Option, + alias: String, + out_expr_var: Option, + out_replace_agg_type: Option, + out_alias_var: Var, + ) { + let column_expr_column = egraph.add(LogicalPlanLanguage::ColumnExprColumn( + ColumnExprColumn(column), + )); + let column_expr = egraph.add(LogicalPlanLanguage::ColumnExpr([column_expr_column])); + let new_aggregation_value = match call_agg_type { + Some(call_agg_type) => egraph.add(LogicalPlanLanguage::LiteralExprValue( + LiteralExprValue(ScalarValue::Utf8(Some(call_agg_type))), + )), + None => egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue( + ScalarValue::Null, + ))), + }; + let new_aggregation_expr = + egraph.add(LogicalPlanLanguage::LiteralExpr([new_aggregation_value])); + + if let Some(out_replace_agg_type) = out_replace_agg_type { + subst.insert(out_replace_agg_type, new_aggregation_expr); + } + + let add_filters_value = egraph.add(LogicalPlanLanguage::LiteralExprValue( + LiteralExprValue(ScalarValue::Null), + )); + let add_filters_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([add_filters_value])); + let udaf_name_expr = egraph.add(LogicalPlanLanguage::AggregateUDFExprFun( + AggregateUDFExprFun(PATCH_MEASURE_UDAF_NAME.to_string()), + )); + let udaf_args_expr = egraph.add(LogicalPlanLanguage::AggregateUDFExprArgs(vec![ + column_expr, + new_aggregation_expr, + add_filters_expr, + ])); + let udaf_expr = egraph.add(LogicalPlanLanguage::AggregateUDFExpr([ + udaf_name_expr, + udaf_args_expr, + ])); + + if let Some(out_expr_var) = out_expr_var { + subst.insert(out_expr_var, udaf_expr); + } + + let alias_expr_alias = egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias( + alias.clone(), + ))); + subst.insert(out_alias_var, alias_expr_alias); + } + + fn pushdown_measure_impl( + egraph: &mut CubeEGraph, + subst: &mut Subst, + original_expr_var: Var, + column_var: Option, + fun_name_var: Option, + distinct_var: Option, + cube_members_var: Var, + out_expr_var: Var, + out_alias_var: Var, + meta: &MetaContext, + disable_strict_agg_type_match: bool, + ) -> bool { + let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) else { + return false; + }; + + for fun in fun_name_var + .map(|fun_var| { + var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun) + .map(|fun| Some(fun.clone())) + .collect() + }) + .unwrap_or(vec![None]) + { + for distinct in distinct_var + .map(|distinct_var| { + var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct) + .map(|d| *d) + .collect() + }) + .unwrap_or(vec![false]) + { + let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct); + + let column_iter = if let Some(column_var) = column_var { + var_iter!(egraph[subst[column_var]], ColumnExprColumn) + .cloned() + .collect() + } else { + vec![Column::from_name(MemberRules::default_count_measure_name())] + }; + + if let Some(member_names_to_expr) = &mut egraph + .index_mut(subst[cube_members_var]) + .data + .member_name_to_expr + { + for column in column_iter { + if let Some((&(Some(ref member), _, _), _)) = + LogicalPlanData::do_find_member_by_alias( + member_names_to_expr, + &column.name, + ) + { + if let Some(measure) = meta.find_measure_with_name(member) { + let Some(call_agg_type) = &call_agg_type else { + // call_agg_type is None, rewrite as is + Self::insert_regular_measure( + egraph, + subst, + column, + alias, + out_expr_var, + out_alias_var, + ); + + return true; + }; + + if measure + .is_same_agg_type(call_agg_type, disable_strict_agg_type_match) + { + Self::insert_regular_measure( + egraph, + subst, + column, + alias, + out_expr_var, + out_alias_var, + ); + + return true; + } + + if measure.allow_replace_agg_type( + call_agg_type, + disable_strict_agg_type_match, + ) { + Self::insert_patch_measure( + egraph, + subst, + column, + Some(call_agg_type.clone()), + alias, + Some(out_expr_var), + None, + out_alias_var, + ); + + return true; + } + } + } + } + } + } + } + + false + } + fn pushdown_measure( &self, original_expr_var: &'static str, @@ -827,7 +1116,8 @@ impl WrapperRules { // TODO support cast push downs _cast_data_type_var: Option<&'static str>, cube_members_var: &'static str, - measure_out_var: &'static str, + out_expr_var: &'static str, + out_alias_var: &'static str, ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { let original_expr_var = var!(original_expr_var); let column_var = column_var.map(|v| var!(v)); @@ -835,81 +1125,138 @@ impl WrapperRules { let distinct_var = distinct_var.map(|v| var!(v)); // let cast_data_type_var = cast_data_type_var.map(|v| var!(v)); let cube_members_var = var!(cube_members_var); - let measure_out_var = var!(measure_out_var); + let out_expr_var = var!(out_expr_var); + let out_alias_var = var!(out_alias_var); let meta = self.meta_context.clone(); let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match(); move |egraph, subst| { - if let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) { - for fun in fun_name_var - .map(|fun_var| { - var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun) - .map(|fun| Some(fun.clone())) - .collect() - }) - .unwrap_or(vec![None]) + Self::pushdown_measure_impl( + egraph, + subst, + original_expr_var, + column_var, + fun_name_var, + distinct_var, + cube_members_var, + out_expr_var, + out_alias_var, + &meta, + disable_strict_agg_type_match, + ) + } + } + + fn transform_filtered_measure( + &self, + aggr_expr_var: &'static str, + literal_var: &'static str, + column_var: &'static str, + fun_name_var: &'static str, + cube_members_var: &'static str, + replace_agg_type_var: &'static str, + out_measure_alias_var: &'static str, + ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { + let aggr_expr_var = var!(aggr_expr_var); + let literal_var = var!(literal_var); + let column_var = var!(column_var); + let fun_name_var = var!(fun_name_var); + let cube_members_var = var!(cube_members_var); + let replace_agg_type_var = var!(replace_agg_type_var); + let out_measure_alias_var = var!(out_measure_alias_var); + + let meta = self.meta_context.clone(); + let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match(); + + move |egraph, subst| { + match &egraph[subst[literal_var]].data.constant { + Some(ConstantFolding::Scalar(_)) => { + // Do nothing + } + _ => { + return false; + } + } + + let Some(alias) = original_expr_name(egraph, subst[aggr_expr_var]) else { + return false; + }; + + for fun in var_iter!(egraph[subst[fun_name_var]], AggregateFunctionExprFun) + .cloned() + .collect::>() + { + let call_agg_type = MemberRules::get_agg_type(Some(&fun), false); + + let column_iter = var_iter!(egraph[subst[column_var]], ColumnExprColumn) + .cloned() + .collect::>(); + + if let Some(member_names_to_expr) = &mut egraph + .index_mut(subst[cube_members_var]) + .data + .member_name_to_expr { - for distinct in distinct_var - .map(|distinct_var| { - var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct) - .map(|d| *d) - .collect() - }) - .unwrap_or(vec![false]) - { - let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct); - - let column_iter = if let Some(column_var) = column_var { - var_iter!(egraph[subst[column_var]], ColumnExprColumn) - .cloned() - .collect() - } else { - vec![Column::from_name(MemberRules::default_count_measure_name())] - }; - - if let Some(member_names_to_expr) = &mut egraph - .index_mut(subst[cube_members_var]) - .data - .member_name_to_expr + for column in column_iter { + if let Some((&(Some(ref member), _, _), _)) = + LogicalPlanData::do_find_member_by_alias( + member_names_to_expr, + &column.name, + ) { - for column in column_iter { - if let Some((&(Some(ref member), _, _), _)) = - LogicalPlanData::do_find_member_by_alias( - member_names_to_expr, - &column.name, - ) + if let Some(measure) = meta.find_measure_with_name(member) { + if !measure.allow_add_filter(call_agg_type.as_deref()) { + continue; + } + + let Some(call_agg_type) = &call_agg_type else { + // call_agg_type is None, rewrite as is + Self::insert_patch_measure( + egraph, + subst, + column, + None, + alias, + None, + Some(replace_agg_type_var), + out_measure_alias_var, + ); + + return true; + }; + + if measure + .is_same_agg_type(call_agg_type, disable_strict_agg_type_match) { - if let Some(measure) = meta.find_measure_with_name(member) { - if call_agg_type.is_none() - || measure.is_same_agg_type( - call_agg_type.as_ref().unwrap(), - disable_strict_agg_type_match, - ) - { - let column_expr_column = - egraph.add(LogicalPlanLanguage::ColumnExprColumn( - ColumnExprColumn(column.clone()), - )); - - let column_expr = - egraph.add(LogicalPlanLanguage::ColumnExpr([ - column_expr_column, - ])); - let alias_expr_alias = - egraph.add(LogicalPlanLanguage::AliasExprAlias( - AliasExprAlias(alias.clone()), - )); - - let alias_expr = - egraph.add(LogicalPlanLanguage::AliasExpr([ - column_expr, - alias_expr_alias, - ])); - - subst.insert(measure_out_var, alias_expr); - - return true; - } - } + Self::insert_patch_measure( + egraph, + subst, + column, + None, + alias, + None, + Some(replace_agg_type_var), + out_measure_alias_var, + ); + + return true; + } + + if measure.allow_replace_agg_type( + call_agg_type, + disable_strict_agg_type_match, + ) { + Self::insert_patch_measure( + egraph, + subst, + column, + Some(call_agg_type.clone()), + alias, + None, + Some(replace_agg_type_var), + out_measure_alias_var, + ); + + return true; } } } @@ -918,6 +1265,8 @@ impl WrapperRules { } false + + // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc } } } diff --git a/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs b/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs index d60c8fb138d9f..04ca392ebc79b 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs @@ -117,21 +117,21 @@ GROUP BY .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); // Dimension from ungrouped side assert!(query_plan .as_logical_plan() .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); // Dimension from grouped side assert!(query_plan .as_logical_plan() .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"\\\"kibana_grouped\\\".\\\"avg_price\\\"\""#)); + .contains(r#"\"sql\":\"\\\"kibana_grouped\\\".\\\"avg_price\\\"\""#)); } /// Simple join between ungrouped and grouped query should plan as a push-to-Cube query @@ -200,21 +200,21 @@ ON ( .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); // Dimension from ungrouped side assert!(query_plan .as_logical_plan() .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); // Dimension from grouped side assert!(query_plan .as_logical_plan() .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"\\\"kibana_grouped\\\".\\\"avg_price\\\"\""#)); + .contains(r#"\"sql\":\"\\\"kibana_grouped\\\".\\\"avg_price\\\"\""#)); } /// Join between ungrouped and grouped query with two columns join condition @@ -284,7 +284,7 @@ ON ( .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); } /// Join between ungrouped and grouped query with filter + sort + limit @@ -362,7 +362,7 @@ GROUP BY 1 .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); } #[tokio::test] @@ -451,7 +451,7 @@ LIMIT 1000 assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.avgPrice}\""#)); // Outer sort assert!(wrapped_sql_node @@ -553,7 +553,7 @@ GROUP BY assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"(CAST(${MultiTypeCube.dim_str1} AS STRING) = $1)\""#)); + .contains(r#"\"sql\":\"(CAST(${MultiTypeCube.dim_str1} AS STRING) = $1)\""#)); // Dimension from top aggregation @@ -564,14 +564,14 @@ GROUP BY assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"CAST(${MultiTypeCube.dim_str0} AS STRING)\""#)); + .contains(r#"\"sql\":\"CAST(${MultiTypeCube.dim_str0} AS STRING)\""#)); // Measure from top aggregation assert_eq!(wrapped_sql_node.request.measures.as_ref().unwrap().len(), 1); assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"${MultiTypeCube.countDistinct}\""#)); + .contains(r#"\"sql\":\"${MultiTypeCube.countDistinct}\""#)); } /// Ungrouped-grouped join with complex condition should plan as push-to-Cube query @@ -664,14 +664,14 @@ GROUP BY .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${MultiTypeCube.avgPrice}\""#)); + .contains(r#"\"sql\":\"${MultiTypeCube.avgPrice}\""#)); // Dimension from ungrouped side assert!(query_plan .as_logical_plan() .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${MultiTypeCube.dim_str0}\""#)); + .contains(r#"\"sql\":\"${MultiTypeCube.dim_str0}\""#)); } #[tokio::test] @@ -747,7 +747,7 @@ GROUP BY assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"(((${KibanaSampleDataEcommerce.notes} >= $1) AND (${KibanaSampleDataEcommerce.notes} <= $2)) OR ((${KibanaSampleDataEcommerce.notes} >= $3) AND (${KibanaSampleDataEcommerce.notes} <= $4)))\""#)); + .contains(r#"\"sql\":\"(((${KibanaSampleDataEcommerce.notes} >= $1) AND (${KibanaSampleDataEcommerce.notes} <= $2)) OR ((${KibanaSampleDataEcommerce.notes} >= $3) AND (${KibanaSampleDataEcommerce.notes} <= $4)))\""#)); // Dimension from top aggregation assert_eq!( @@ -757,14 +757,14 @@ GROUP BY assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.customer_gender}\""#)); // Measure from top aggregation assert_eq!(wrapped_sql_node.request.measures.as_ref().unwrap().len(), 1); assert!(wrapped_sql_node .wrapped_sql .sql - .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.sumPrice}\""#)); + .contains(r#"\"sql\":\"${KibanaSampleDataEcommerce.sumPrice}\""#)); } #[tokio::test] diff --git a/rust/cubesql/cubesql/src/compile/test/test_filters.rs b/rust/cubesql/cubesql/src/compile/test/test_filters.rs index a400253886f82..affbc71d47647 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_filters.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_filters.rs @@ -97,7 +97,7 @@ async fn test_filter_dim_in_null() { .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"${MultiTypeCube.dim_str1} IN (NULL)\""#)); + .contains(r#"\"sql\":\"${MultiTypeCube.dim_str1} IN (NULL)\""#)); } #[tokio::test] @@ -131,5 +131,5 @@ SELECT dim_str0 FROM MultiTypeCube WHERE (dim_str1 IS NULL OR dim_str1 IN (NULL) .find_cube_scan_wrapped_sql() .wrapped_sql .sql - .contains(r#"\"expr\":\"((${MultiTypeCube.dim_str1} IS NULL) OR (${MultiTypeCube.dim_str1} IN (NULL) AND FALSE))\""#)); + .contains(r#"\"sql\":\"((${MultiTypeCube.dim_str1} IS NULL) OR (${MultiTypeCube.dim_str1} IN (NULL) AND FALSE))\""#)); } diff --git a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs index 0973f7d374261..41dc09e54b3a3 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs @@ -925,9 +925,9 @@ async fn test_case_wrapper_with_system_fields() { .wrapped_sql .sql .contains( - "\\\"cube_name\\\":\\\"KibanaSampleDataEcommerce\\\",\\\"alias\\\":\\\"user\\\"" + "\\\"cubeName\\\":\\\"KibanaSampleDataEcommerce\\\",\\\"alias\\\":\\\"user\\\"" ), - r#"SQL contains `\"cube_name\":\"KibanaSampleDataEcommerce\",\"alias\":\"user\"` {}"#, + r#"SQL contains `\"cubeName\":\"KibanaSampleDataEcommerce\",\"alias\":\"user\"` {}"#, logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); @@ -1259,29 +1259,38 @@ async fn test_wrapper_filter_flatten() { .request, TransportLoadRequestQuery { measures: Some(vec![json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "sum_kibanasample", - "cube_params": ["KibanaSampleDataEcommerce"], - // This is grouped query, KibanaSampleDataEcommerce.sumPrice is correct in this context - // SUM(sumPrice) will be incrrect here, it would lead to SUM(SUM(sql)) in generated query - "expr": "${KibanaSampleDataEcommerce.sumPrice}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + // This is grouped query, KibanaSampleDataEcommerce.sumPrice is correct in this context + // SUM(sumPrice) will be incrrect here, it would lead to SUM(SUM(sql)) in generated query + "sql": "${KibanaSampleDataEcommerce.sumPrice}", + }, + "groupingSet": null, }) .to_string(),]), dimensions: Some(vec![json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "customer_gender", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "${KibanaSampleDataEcommerce.customer_gender}", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "${KibanaSampleDataEcommerce.customer_gender}", + }, + "groupingSet": null, }) .to_string(),]), segments: Some(vec![json!({ - "cube_name": "KibanaSampleDataEcommerce", + "cubeName": "KibanaSampleDataEcommerce", "alias": "lower_kibanasamp", - "cube_params": ["KibanaSampleDataEcommerce"], - "expr": "(LOWER(${KibanaSampleDataEcommerce.customer_gender}) = $0$)", - "grouping_set": null, + "expr": { + "type": "SqlFunction", + "cubeParams": ["KibanaSampleDataEcommerce"], + "sql": "(LOWER(${KibanaSampleDataEcommerce.customer_gender}) = $0$)", + }, + "groupingSet": null, }) .to_string(),]), time_dimensions: None, @@ -1684,7 +1693,7 @@ async fn select_agg_where_false() { // Final query uses grouped query to Cube.js with WHERE FALSE, but without LIMIT 0 assert!(!sql.contains("\"ungrouped\":")); - assert!(sql.contains(r#"\"expr\":\"FALSE\""#)); + assert!(sql.contains(r#"\"sql\":\"FALSE\""#)); assert!(sql.contains(r#""limit": 50000"#)); } @@ -1735,7 +1744,161 @@ async fn wrapper_dimension_agg_where_false() { // Final query uses grouped query to Cube.js with WHERE FALSE, but without LIMIT 0 assert!(!sql.contains("\"ungrouped\":")); - assert!(sql.contains(r#"\"expr\":\"FALSE\""#)); + assert!(sql.contains(r#"\"sql\":\"FALSE\""#)); assert!(!sql.contains(r#""limit""#)); assert!(sql.contains("LIMIT 50000")); } + +/// MIN(avg_measure) should get pushed to Cube with replaced measure +#[tokio::test] +async fn wrapper_min_from_avg_measure() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + // language=PostgreSQL + r#" + SELECT + MIN(avgPrice) + FROM + MultiTypeCube + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + + assert_eq!( + query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .request, + TransportLoadRequestQuery { + measures: Some(vec![json!({ + "cubeName": "MultiTypeCube", + "alias": "min_multitypecub", + "expr": { + "type": "PatchMeasure", + "sourceMeasure": "MultiTypeCube.avgPrice", + "replaceAggregationType": "min", + "addFilters": [], + }, + "groupingSet": null, + }) + .to_string(),]), + dimensions: Some(vec![]), + segments: Some(vec![]), + order: Some(vec![]), + ..Default::default() + } + ); +} + +#[tokio::test] +async fn test_ad_hoc_measure_filter() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + // language=PostgreSQL + r#"SELECT + dim_str0, + AVG( + CASE ( + ( + CAST(TRUNC(EXTRACT(YEAR FROM dim_date0)) AS INTEGER) = 2024 + ) + AND + ( + CAST(TRUNC(EXTRACT(MONTH FROM dim_date0)) AS INTEGER) <= 11 + ) + ) + WHEN TRUE + THEN avgPrice + ELSE NULL + END + ), + SUM( + CASE (dim_str1 = 'foo') + WHEN TRUE + THEN maxPrice + ELSE NULL + END + ) +FROM MultiTypeCube +GROUP BY + 1 +;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + + assert_eq!( + query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .request, + TransportLoadRequestQuery { + measures: Some(vec![ + json!({ + "cubeName": "MultiTypeCube", + "alias": "avg_case_cast_tr", + "expr": { + "type": "PatchMeasure", + "sourceMeasure": "MultiTypeCube.avgPrice", + "replaceAggregationType": null, + "addFilters": [{ + "cubeParams": ["MultiTypeCube"], + "sql": "(((CAST(TRUNC(EXTRACT(YEAR FROM ${MultiTypeCube.dim_date0})) AS INTEGER) = 2024) AND (CAST(TRUNC(EXTRACT(MONTH FROM ${MultiTypeCube.dim_date0})) AS INTEGER) <= 11)) = TRUE)" + }], + }, + "groupingSet": null, + }).to_string(), + json!({ + "cubeName": "MultiTypeCube", + "alias": "sum_case_multity", + "expr": { + "type": "PatchMeasure", + "sourceMeasure": "MultiTypeCube.maxPrice", + "replaceAggregationType": "sum", + "addFilters": [{ + "cubeParams": ["MultiTypeCube"], + "sql": "((${MultiTypeCube.dim_str1} = $0$) = TRUE)" + }], + }, + "groupingSet": null, + }).to_string(), + ]), + dimensions: Some(vec![json!({ + "cubeName": "MultiTypeCube", + "alias": "dim_str0", + "expr": { + "type": "SqlFunction", + "cubeParams": ["MultiTypeCube"], + "sql": "${MultiTypeCube.dim_str0}", + }, + "groupingSet": null, + }).to_string(),]), + segments: Some(vec![]), + order: Some(vec![]), + ..Default::default() + } + ); +} diff --git a/rust/cubesql/cubesql/src/transport/ext.rs b/rust/cubesql/cubesql/src/transport/ext.rs index f9b35687a16ef..2f9a9609064e2 100644 --- a/rust/cubesql/cubesql/src/transport/ext.rs +++ b/rust/cubesql/cubesql/src/transport/ext.rs @@ -10,6 +10,10 @@ pub trait V1CubeMetaMeasureExt { fn is_same_agg_type(&self, expect_agg_type: &str, disable_strict_match: bool) -> bool; + fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool; + + fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool; + fn get_sql_type(&self) -> ColumnType; } @@ -24,36 +28,73 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure { if disable_strict_match { return true; } - if self.agg_type.is_some() { - if expect_agg_type.eq(&"countDistinct".to_string()) { - let agg_type = self.agg_type.as_ref().unwrap(); - - agg_type.eq(&"countDistinct".to_string()) - || agg_type.eq(&"countDistinctApprox".to_string()) - || agg_type.eq(&"number".to_string()) - } else if expect_agg_type.eq(&"sum".to_string()) { - let agg_type = self.agg_type.as_ref().unwrap(); - - agg_type.eq(&"sum".to_string()) - || agg_type.eq(&"count".to_string()) - || agg_type.eq(&"number".to_string()) - } else if expect_agg_type.eq(&"min".to_string()) - || expect_agg_type.eq(&"max".to_string()) - { - let agg_type = self.agg_type.as_ref().unwrap(); - - agg_type.eq(&"number".to_string()) - || agg_type.eq(&"string".to_string()) - || agg_type.eq(&"time".to_string()) - || agg_type.eq(&"boolean".to_string()) - || agg_type.eq(expect_agg_type) - } else { - let agg_type = self.agg_type.as_ref().unwrap(); - - agg_type.eq(&"number".to_string()) || agg_type.eq(expect_agg_type) + let Some(agg_type) = &self.agg_type else { + return false; + }; + match expect_agg_type { + "countDistinct" => { + agg_type == "countDistinct" + || agg_type == "countDistinctApprox" + || agg_type == "number" } - } else { - false + "sum" => agg_type == "sum" || agg_type == "count" || agg_type == "number", + "min" | "max" => { + agg_type == "number" + || agg_type == "string" + || agg_type == "time" + || agg_type == "boolean" + || agg_type == expect_agg_type + } + _ => agg_type == "number" || agg_type == expect_agg_type, + } + } + + // This should be aligned with BaseMeasure.preparePatchedMeasure + // See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16 + fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool { + if disable_strict_match { + return true; + } + let Some(agg_type) = &self.agg_type else { + return false; + }; + + match (agg_type.as_str(), query_agg_type) { + ( + "sum" | "avg" | "min" | "max", + "sum" | "avg" | "min" | "max" | "count_distinct" | "count_distinct_approx", + ) => true, + + ( + "count_distinct" | "count_distinct_approx", + "count_distinct" | "count_distinct_approx", + ) => true, + + _ => false, + } + } + + // This should be aligned with BaseMeasure.preparePatchedMeasure + // See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16 + fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool { + let Some(agg_type) = &self.agg_type else { + return false; + }; + + let agg_type = match query_agg_type { + Some(query_agg_type) => query_agg_type, + None => agg_type, + }; + + match agg_type { + "sum" + | "avg" + | "min" + | "max" + | "count" + | "count_distinct" + | "count_distinct_approx" => true, + _ => false, } } diff --git a/yarn.lock b/yarn.lock index fc962405e7d51..231c70b568cb4 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11678,6 +11678,11 @@ asn1@^0.2.6, asn1@~0.2.3: dependencies: safer-buffer "~2.1.0" +assert-never@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/assert-never/-/assert-never-1.4.0.tgz#b0d4988628c87f35eb94716cc54422a63927e175" + integrity sha512-5oJg84os6NMQNl27T9LnZkvvqzvAnHu03ShCnoj6bsJwS7L8AO4lf+C/XjK/nvzEqQB744moC6V128RucQd1jA== + assert-options@0.7.0: version "0.7.0" resolved "https://registry.yarnpkg.com/assert-options/-/assert-options-0.7.0.tgz#82c27618d9c0baa5e9da8ef607ee261a44ed6e5e"