diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 8d981fd1..0d2dc479 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -227,6 +227,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I const session = new DBSQLSession({ handle: definedOrError(response.sessionHandle), context: this, + serverProtocolVersion: response.serverProtocolVersion, }); this.sessions.add(session); return session; diff --git a/lib/DBSQLSession.ts b/lib/DBSQLSession.ts index 009a8c8b..13146cdc 100644 --- a/lib/DBSQLSession.ts +++ b/lib/DBSQLSession.ts @@ -12,6 +12,8 @@ import { TSparkDirectResults, TSparkArrowTypes, TSparkParameter, + TProtocolVersion, + TExecuteStatementReq, } from '../thrift/TCLIService_types'; import IDBSQLSession, { ExecuteStatementOptions, @@ -29,7 +31,7 @@ import IOperation from './contracts/IOperation'; import DBSQLOperation from './DBSQLOperation'; import Status from './dto/Status'; import InfoValue from './dto/InfoValue'; -import { definedOrError, LZ4 } from './utils'; +import { definedOrError, LZ4, ProtocolVersion } from './utils'; import CloseableCollection from './utils/CloseableCollection'; import { LogLevel } from './contracts/IDBSQLLogger'; import HiveDriverError from './errors/HiveDriverError'; @@ -74,13 +76,16 @@ function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undef }; } -function getArrowOptions(config: ClientConfig): { +function getArrowOptions( + config: ClientConfig, + serverProtocolVersion: TProtocolVersion | undefined | null, +): { canReadArrowResult: boolean; useArrowNativeTypes?: TSparkArrowTypes; } { const { arrowEnabled = true, useArrowNativeTypes = true } = config; - if (!arrowEnabled) { + if (!arrowEnabled || !ProtocolVersion.supportsArrowMetadata(serverProtocolVersion)) { return { canReadArrowResult: false, }; @@ -136,6 +141,7 @@ function getQueryParameters( interface DBSQLSessionConstructorOptions { handle: TSessionHandle; context: IClientContext; + serverProtocolVersion?: TProtocolVersion; } export default class DBSQLSession implements IDBSQLSession { @@ -145,14 +151,28 @@ export default class DBSQLSession implements IDBSQLSession { private isOpen = true; + private serverProtocolVersion?: TProtocolVersion; + public onClose?: () => void; private operations = new CloseableCollection(); - constructor({ handle, context }: DBSQLSessionConstructorOptions) { + /** + * Helper method to determine if runAsync should be set for metadata operations + * @private + * @returns true if supported by protocol version, undefined otherwise + */ + private getRunAsyncForMetadataOperations(): boolean | undefined { + return ProtocolVersion.supportsAsyncMetadataOperations(this.serverProtocolVersion) ? true : undefined; + } + + constructor({ handle, context, serverProtocolVersion }: DBSQLSessionConstructorOptions) { this.sessionHandle = handle; this.context = context; + // Get the server protocol version from the provided parameter (from TOpenSessionResp) + this.serverProtocolVersion = serverProtocolVersion; this.context.getLogger().log(LogLevel.debug, `Session created with id: ${this.id}`); + this.context.getLogger().log(LogLevel.debug, `Server protocol version: ${this.serverProtocolVersion}`); } public get id() { @@ -193,17 +213,29 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); - const operationPromise = driver.executeStatement({ + + const request = new TExecuteStatementReq({ sessionHandle: this.sessionHandle, statement, queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined, runAsync: true, ...getDirectResultsOptions(options.maxRows, clientConfig), - ...getArrowOptions(clientConfig), - canDownloadResult: options.useCloudFetch ?? clientConfig.useCloudFetch, - parameters: getQueryParameters(options.namedParameters, options.ordinalParameters), - canDecompressLZ4Result: (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4), + ...getArrowOptions(clientConfig, this.serverProtocolVersion), }); + + if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) { + request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters); + } + + if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion)) { + request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4); + } + + if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) { + request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch; + } + + const operationPromise = driver.executeStatement(request); const response = await this.handleResponse(operationPromise); const operation = this.createOperation(response); @@ -352,9 +384,10 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getTypeInfo({ sessionHandle: this.sessionHandle, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -371,9 +404,10 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getCatalogs({ sessionHandle: this.sessionHandle, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -390,11 +424,12 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getSchemas({ sessionHandle: this.sessionHandle, catalogName: request.catalogName, schemaName: request.schemaName, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -411,13 +446,14 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getTables({ sessionHandle: this.sessionHandle, catalogName: request.catalogName, schemaName: request.schemaName, tableName: request.tableName, tableTypes: request.tableTypes, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -434,9 +470,10 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getTableTypes({ sessionHandle: this.sessionHandle, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -453,13 +490,14 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getColumns({ sessionHandle: this.sessionHandle, catalogName: request.catalogName, schemaName: request.schemaName, tableName: request.tableName, columnName: request.columnName, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -476,12 +514,13 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getFunctions({ sessionHandle: this.sessionHandle, catalogName: request.catalogName, schemaName: request.schemaName, functionName: request.functionName, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -492,12 +531,13 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getPrimaryKeys({ sessionHandle: this.sessionHandle, catalogName: request.catalogName, schemaName: request.schemaName, tableName: request.tableName, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); @@ -514,6 +554,7 @@ export default class DBSQLSession implements IDBSQLSession { await this.failIfClosed(); const driver = await this.context.getDriver(); const clientConfig = this.context.getConfig(); + const operationPromise = driver.getCrossReference({ sessionHandle: this.sessionHandle, parentCatalogName: request.parentCatalogName, @@ -522,7 +563,7 @@ export default class DBSQLSession implements IDBSQLSession { foreignCatalogName: request.foreignCatalogName, foreignSchemaName: request.foreignSchemaName, foreignTableName: request.foreignTableName, - runAsync: true, + runAsync: this.getRunAsyncForMetadataOperations(), ...getDirectResultsOptions(request.maxRows, clientConfig), }); const response = await this.handleResponse(operationPromise); diff --git a/lib/utils/index.ts b/lib/utils/index.ts index 963f6b05..b8203c4d 100644 --- a/lib/utils/index.ts +++ b/lib/utils/index.ts @@ -2,5 +2,6 @@ import definedOrError from './definedOrError'; import buildUserAgentString from './buildUserAgentString'; import formatProgress, { ProgressUpdateTransformer } from './formatProgress'; import LZ4 from './lz4'; +import * as ProtocolVersion from './protocolVersion'; -export { definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer, LZ4 }; +export { definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer, LZ4, ProtocolVersion }; diff --git a/lib/utils/protocolVersion.ts b/lib/utils/protocolVersion.ts new file mode 100644 index 00000000..171cfa1a --- /dev/null +++ b/lib/utils/protocolVersion.ts @@ -0,0 +1,95 @@ +import { TProtocolVersion } from '../../thrift/TCLIService_types'; + +/** + * Protocol version information from Thrift TCLIService + * Each version adds certain features to the Spark/Hive API + * + * Databricks only supports SPARK_CLI_SERVICE_PROTOCOL_V1 (0xA501) or higher + */ + +/** + * Check if the current protocol version supports a specific feature + * @param serverProtocolVersion The protocol version received from server in TOpenSessionResp + * @param requiredVersion The minimum protocol version required for a feature + * @returns boolean indicating if the feature is supported + */ +export function isFeatureSupported( + serverProtocolVersion: TProtocolVersion | undefined | null, + requiredVersion: TProtocolVersion, +): boolean { + if (serverProtocolVersion === undefined || serverProtocolVersion === null) { + return false; + } + + return serverProtocolVersion >= requiredVersion; +} + +/** + * Check if parameterized queries are supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V8 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if parameterized queries are supported + */ +export function supportsParameterizedQueries(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8); +} + +/** + * Check if async metadata operations are supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V6 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if async metadata operations are supported + */ +export function supportsAsyncMetadataOperations(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6); +} + +/** + * Check if result persistence mode is supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V7 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if result persistence mode is supported + */ +export function supportsResultPersistenceMode(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7); +} + +/** + * Check if Arrow compression is supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V6 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if compressed Arrow batches are supported + */ +export function supportsArrowCompression(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6); +} + +/** + * Check if Arrow metadata is supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V5 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if Arrow metadata is supported + */ +export function supportsArrowMetadata(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5); +} + +/** + * Check if multiple catalogs are supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V4 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if multiple catalogs are supported + */ +export function supportsMultipleCatalogs(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4); +} + +/** + * Check if cloud object storage fetching is supported + * (Requires SPARK_CLI_SERVICE_PROTOCOL_V3 or higher) + * @param serverProtocolVersion The protocol version from server + * @returns boolean indicating if cloud fetching is supported + */ +export function supportsCloudFetch(serverProtocolVersion: TProtocolVersion | undefined | null): boolean { + return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3); +} diff --git a/tests/e2e/protocol_versions.test.ts b/tests/e2e/protocol_versions.test.ts new file mode 100644 index 00000000..d89fb04e --- /dev/null +++ b/tests/e2e/protocol_versions.test.ts @@ -0,0 +1,363 @@ +/* eslint-disable func-style, no-loop-func */ +import { expect } from 'chai'; +import sinon from 'sinon'; +import Int64 from 'node-int64'; +import { DBSQLClient } from '../../lib'; +import IDBSQLSession from '../../lib/contracts/IDBSQLSession'; +import { TProtocolVersion } from '../../thrift/TCLIService_types'; +import config from './utils/config'; +import IDriver from '../../lib/contracts/IDriver'; + +// Create a list of all SPARK protocol versions +const protocolVersions = [ + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V1, desc: 'V1: no special features' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V2, desc: 'V2: no special features' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, desc: 'V3: cloud fetch' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, desc: 'V4: multiple catalogs' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5, desc: 'V5: arrow metadata' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, desc: 'V6: async metadata, arrow compression' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, desc: 'V7: result persistence mode' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, desc: 'V8: parameterized queries' }, +]; + +/** + * Execute a statement and return results + */ +async function execute(session: IDBSQLSession, statement: string) { + const operation = await session.executeStatement(statement); + const result = await operation.fetchAll(); + await operation.close(); + return result; +} + +describe('Protocol Versions E2E Tests', function () { + // These tests might take longer than the default timeout + this.timeout(60000); + + // Instead of using a loop with functions inside, we'll create a function that returns + // a test suite for each protocol version + protocolVersions.forEach(({ version, desc }) => { + describe(`Protocol ${desc}`, () => { + let client: DBSQLClient; + let session: IDBSQLSession; + + before(function (this: Mocha.Context) { + return (async () => { + try { + client = new DBSQLClient(); + + // Connect to the Databricks SQL service + await client.connect({ + host: config.host, + path: config.path, + token: config.token, + }); + + // Get access to the driver + const getDriverOriginal = client.getDriver.bind(client); + + // Stub getDriver to return a proxied version of the driver with overridden openSession + sinon.stub(client, 'getDriver').callsFake(async () => { + const driver = await getDriverOriginal(); + + // Create a proxy for the driver to intercept openSession calls + const driverProxy = new Proxy(driver, { + get(target, prop) { + if (prop === 'openSession') { + return async (request: any) => { + // Modify the request to use our specific protocol version + const modifiedRequest = { + ...request, + client_protocol_i64: new Int64(version), + }; + return target.openSession(modifiedRequest); + }; + } + return target[prop as keyof IDriver]; + }, + }); + + return driverProxy; + }); + + session = await client.openSession({ + initialCatalog: config.catalog, + initialSchema: config.schema, + }); + } catch (error) { + // eslint-disable-next-line no-console + console.log(`Failed to open session with protocol version ${desc}: ${error}`); + this.skip(); + } + })(); + }); + + after(async () => { + if (session) { + await session.close(); + } + if (client) { + await client.close(); + } + // Restore sinon stubs + sinon.restore(); + }); + + it('should handle various data types', async () => { + // Query testing multiple data types supported by Databricks + const query = ` + SELECT + -- Numeric types + CAST(42 AS TINYINT) AS tiny_int_val, + CAST(1000 AS SMALLINT) AS small_int_val, + CAST(100000 AS INT) AS int_val, + CAST(123456789012345 AS BIGINT) AS bigint_val, -- Using a smaller BIGINT value within JavaScript safe range + CAST(3.14 AS FLOAT) AS float_val, + CAST(3.14159265359 AS DOUBLE) AS double_val, + CAST(123.45 AS DECIMAL(5,2)) AS decimal_val, + + -- String and Binary types + CAST('hello world' AS STRING) AS string_val, + CAST(X'68656C6C6F' AS BINARY) AS binary_val, -- 'hello' in hex + + -- Boolean type + CAST(TRUE AS BOOLEAN) AS boolean_val, + + -- Date and Time types - Use current_date() to ensure consistency with server time zone + current_date() AS date_val, + current_timestamp() AS timestamp_val, + + -- Intervals + INTERVAL '1' DAY AS interval_day, + + -- Complex types + ARRAY(1, 2, 3) AS array_val, + MAP('a', 1, 'b', 2, 'c', 3) AS map_val, + STRUCT(42 AS id, 'test_name' AS name, TRUE AS active) AS struct_val, + + -- Null value + CAST(NULL AS STRING) AS null_val + `; + + const result = await execute(session, query); + expect(result).to.be.an('array'); + expect(result.length).to.equal(1); + + const row = result[0] as any; + + // Test numeric types + expect(row).to.have.property('tiny_int_val'); + expect(row.tiny_int_val).to.equal(42); + + expect(row).to.have.property('small_int_val'); + expect(row.small_int_val).to.equal(1000); + + expect(row).to.have.property('int_val'); + expect(row.int_val).to.equal(100000); + + expect(row).to.have.property('bigint_val'); + // Using a smaller bigint value that can be safely represented in JavaScript + expect(Number(row.bigint_val)).to.equal(123456789012345); + + expect(row).to.have.property('float_val'); + expect(row.float_val).to.be.closeTo(3.14, 0.001); // Allow small precision differences + + expect(row).to.have.property('double_val'); + expect(row.double_val).to.be.closeTo(3.14159265359, 0.00000000001); + + expect(row).to.have.property('decimal_val'); + expect(parseFloat(row.decimal_val)).to.be.closeTo(123.45, 0.001); + + // Test string and binary types + expect(row).to.have.property('string_val'); + expect(row.string_val).to.equal('hello world'); + + expect(row).to.have.property('binary_val'); + // Binary might be returned in different formats depending on protocol version + + // Test boolean type + expect(row).to.have.property('boolean_val'); + expect(row.boolean_val).to.be.true; + + // Test date type + expect(row).to.have.property('date_val'); + // Date may be returned as a Date object, string, or other format depending on protocol version + const dateVal = row.date_val; + + if (dateVal instanceof Date) { + // If it's a Date object, just verify it's a valid date in approximately the right range + expect(dateVal.getFullYear()).to.be.at.least(2023); + expect(dateVal).to.be.an.instanceof(Date); + } else if (typeof dateVal === 'string') { + // If it's a string, verify it contains a date-like format + expect(/\d{4}[-/]\d{1,2}[-/]\d{1,2}/.test(dateVal) || /\d{1,2}[-/]\d{1,2}[-/]\d{4}/.test(dateVal)).to.be.true; + } else { + // Otherwise just make sure it exists + expect(dateVal).to.exist; + } + + // Test timestamp type + expect(row).to.have.property('timestamp_val'); + const timestampVal = row.timestamp_val; + + if (timestampVal instanceof Date) { + // If it's a Date object, verify it's a valid date-time + expect(timestampVal.getFullYear()).to.be.at.least(2023); + expect(timestampVal).to.be.an.instanceof(Date); + } else if (typeof timestampVal === 'string') { + // If it's a string, verify it contains date and time components + expect(/\d{4}[-/]\d{1,2}[-/]\d{1,2}/.test(timestampVal)).to.be.true; // has date part + expect(/\d{1,2}:\d{1,2}(:\d{1,2})?/.test(timestampVal)).to.be.true; // has time part + } else { + // Otherwise just make sure it exists + expect(timestampVal).to.exist; + } + + // Test interval + expect(row).to.have.property('interval_day'); + + // Test array type + expect(row).to.have.property('array_val'); + const arrayVal = row.array_val; + + // Handle various ways arrays might be represented + if (Array.isArray(arrayVal)) { + expect(arrayVal).to.have.lengthOf(3); + expect(arrayVal).to.include.members([1, 2, 3]); + } else if (typeof arrayVal === 'string') { + // Sometimes arrays might be returned as strings like "[1,2,3]" + expect(arrayVal).to.include('1'); + expect(arrayVal).to.include('2'); + expect(arrayVal).to.include('3'); + } else { + // For other formats, just check it exists + expect(arrayVal).to.exist; + } + + // Test map type + expect(row).to.have.property('map_val'); + const mapVal = row.map_val; + + // Maps could be returned in several formats depending on the protocol version + if (typeof mapVal === 'object' && mapVal !== null && !Array.isArray(mapVal)) { + // If returned as a plain JavaScript object + expect(mapVal).to.have.property('a', 1); + expect(mapVal).to.have.property('b', 2); + expect(mapVal).to.have.property('c', 3); + } else if (typeof mapVal === 'string') { + // Sometimes might be serialized as string + expect(mapVal).to.include('a'); + expect(mapVal).to.include('b'); + expect(mapVal).to.include('c'); + expect(mapVal).to.include('1'); + expect(mapVal).to.include('2'); + expect(mapVal).to.include('3'); + } else { + // For other formats, just check it exists + expect(mapVal).to.exist; + } + + // Test struct type + expect(row).to.have.property('struct_val'); + const structVal = row.struct_val; + + // Structs could be represented differently based on protocol version + if (typeof structVal === 'object' && structVal !== null && !Array.isArray(structVal)) { + // If returned as a plain JavaScript object + expect(structVal).to.have.property('id', 42); + expect(structVal).to.have.property('name', 'test_name'); + expect(structVal).to.have.property('active', true); + } else if (typeof structVal === 'string') { + // If serialized as string + expect(structVal).to.include('42'); + expect(structVal).to.include('test_name'); + } else { + // For other formats, just check it exists + expect(structVal).to.exist; + } + + // Test null value + expect(row).to.have.property('null_val'); + expect(row.null_val).to.be.null; + }); + + it('should get catalogs', async () => { + const operation = await session.getCatalogs(); + const catalogs = await operation.fetchAll(); + await operation.close(); + + expect(catalogs).to.be.an('array'); + expect(catalogs.length).to.be.at.least(1); + expect(catalogs[0]).to.have.property('TABLE_CAT'); + }); + + it('should get schemas', async () => { + const operation = await session.getSchemas({ catalogName: config.catalog }); + const schemas = await operation.fetchAll(); + await operation.close(); + + expect(schemas).to.be.an('array'); + expect(schemas.length).to.be.at.least(1); + expect(schemas[0]).to.have.property('TABLE_SCHEM'); + }); + + it('should get table types', async () => { + const operation = await session.getTableTypes(); + const tableTypes = await operation.fetchAll(); + await operation.close(); + + expect(tableTypes).to.be.an('array'); + expect(tableTypes.length).to.be.at.least(1); + expect(tableTypes[0]).to.have.property('TABLE_TYPE'); + }); + + it('should get tables', async () => { + const operation = await session.getTables({ + catalogName: config.catalog, + schemaName: config.schema, + }); + const tables = await operation.fetchAll(); + await operation.close(); + + expect(tables).to.be.an('array'); + // There might not be any tables, so we don't assert on the length + if (tables.length > 0) { + expect(tables[0]).to.have.property('TABLE_NAME'); + } + }); + + it('should get columns from current schema', function (this: Mocha.Context) { + return (async () => { + // First get a table name from the current schema + const tablesOp = await session.getTables({ + catalogName: config.catalog, + schemaName: config.schema, + }); + const tables = await tablesOp.fetchAll(); + await tablesOp.close(); + + if (tables.length === 0) { + // eslint-disable-next-line no-console + console.log('No tables found in the schema, skipping column test'); + this.skip(); + return; + } + + const tableName = (tables[0] as any).TABLE_NAME; + + const operation = await session.getColumns({ + catalogName: config.catalog, + schemaName: config.schema, + tableName, + }); + const columns = await operation.fetchAll(); + await operation.close(); + + expect(columns).to.be.an('array'); + expect(columns.length).to.be.at.least(1); + expect(columns[0]).to.have.property('COLUMN_NAME'); + })(); + }); + }); + }); +}); diff --git a/tests/unit/DBSQLClient.test.ts b/tests/unit/DBSQLClient.test.ts index f4ac593f..f942c6b8 100644 --- a/tests/unit/DBSQLClient.test.ts +++ b/tests/unit/DBSQLClient.test.ts @@ -16,6 +16,7 @@ import IThriftClient from '../../lib/contracts/IThriftClient'; import IAuthentication from '../../lib/connection/contracts/IAuthentication'; import AuthProviderStub from './.stubs/AuthProviderStub'; import ConnectionProviderStub from './.stubs/ConnectionProviderStub'; +import { TProtocolVersion } from '../../thrift/TCLIService_types'; const connectOptions = { host: '127.0.0.1', @@ -155,6 +156,84 @@ describe('DBSQLClient.openSession', () => { expect(error.message).to.be.eq('DBSQLClient: not connected'); } }); + + it('should correctly pass server protocol version to session', async () => { + const client = new DBSQLClient(); + const thriftClient = new ThriftClientStub(); + sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + + // Test with default protocol version (SPARK_CLI_SERVICE_PROTOCOL_V8) + { + const session = await client.openSession(); + expect(session).instanceOf(DBSQLSession); + expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, + ); + } + + { + thriftClient.openSessionResp = { + ...thriftClient.openSessionResp, + serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + }; + + const session = await client.openSession(); + expect(session).instanceOf(DBSQLSession); + expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + ); + } + }); + + it('should affect session behavior based on protocol version', async () => { + const client = new DBSQLClient(); + const thriftClient = new ThriftClientStub(); + sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + + // With protocol version V6 - should support async metadata operations + { + thriftClient.openSessionResp = { + ...thriftClient.openSessionResp, + serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, + }; + + const session = await client.openSession(); + expect(session).instanceOf(DBSQLSession); + + // Spy on driver.getTypeInfo to check if runAsync is set + const driver = await client.getDriver(); + const getTypeInfoSpy = sinon.spy(driver, 'getTypeInfo'); + + await session.getTypeInfo(); + + expect(getTypeInfoSpy.calledOnce).to.be.true; + expect(getTypeInfoSpy.firstCall.args[0].runAsync).to.be.true; + + getTypeInfoSpy.restore(); + } + + // With protocol version V5 - should NOT support async metadata operations + { + thriftClient.openSessionResp = { + ...thriftClient.openSessionResp, + serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5, + }; + + const session = await client.openSession(); + expect(session).instanceOf(DBSQLSession); + + // Spy on driver.getTypeInfo to check if runAsync is undefined + const driver = await client.getDriver(); + const getTypeInfoSpy = sinon.spy(driver, 'getTypeInfo'); + + await session.getTypeInfo(); + + expect(getTypeInfoSpy.calledOnce).to.be.true; + expect(getTypeInfoSpy.firstCall.args[0].runAsync).to.be.undefined; + + getTypeInfoSpy.restore(); + } + }); }); describe('DBSQLClient.getClient', () => { diff --git a/tests/unit/DBSQLSession.test.ts b/tests/unit/DBSQLSession.test.ts index 460047f5..055483ad 100644 --- a/tests/unit/DBSQLSession.test.ts +++ b/tests/unit/DBSQLSession.test.ts @@ -5,7 +5,7 @@ import DBSQLSession, { numberToInt64 } from '../../lib/DBSQLSession'; import InfoValue from '../../lib/dto/InfoValue'; import Status from '../../lib/dto/Status'; import DBSQLOperation from '../../lib/DBSQLOperation'; -import { TSessionHandle } from '../../thrift/TCLIService_types'; +import { TSessionHandle, TProtocolVersion } from '../../thrift/TCLIService_types'; import ClientContextStub from './.stubs/ClientContextStub'; const sessionHandleStub: TSessionHandle = { @@ -105,6 +105,81 @@ describe('DBSQLSession', () => { } }); }); + + describe('executeStatement with different protocol versions', () => { + const protocolVersions = [ + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V1, desc: 'V1: no special features' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V2, desc: 'V2: no special features' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, desc: 'V3: cloud fetch' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, desc: 'V4: multiple catalogs' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5, desc: 'V5: arrow metadata' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, desc: 'V6: async metadata, arrow compression' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, desc: 'V7: result persistence mode' }, + { version: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, desc: 'V8: parameterized queries' }, + ]; + + for (const { version, desc } of protocolVersions) { + it(`should properly format request with protocol version ${desc}`, async () => { + const context = new ClientContextStub(); + const driver = sinon.spy(context.driver); + const statement = 'SELECT * FROM table'; + const options = { + maxRows: 10, + queryTimeout: 100, + namedParameters: { param1: 'value1' }, + useCloudFetch: true, + useLZ4Compression: true, + }; + + const session = new DBSQLSession({ + handle: sessionHandleStub, + context, + serverProtocolVersion: version, + }); + + await session.executeStatement(statement, options); + + expect(driver.executeStatement.callCount).to.eq(1); + const req = driver.executeStatement.firstCall.args[0]; + + // Basic fields that should always be present + expect(req.sessionHandle.sessionId.guid).to.deep.equal(sessionHandleStub.sessionId.guid); + expect(req.sessionHandle.sessionId.secret).to.deep.equal(sessionHandleStub.sessionId.secret); + expect(req.statement).to.equal(statement); + expect(req.runAsync).to.be.true; + expect(req.queryTimeout).to.deep.equal(numberToInt64(options.queryTimeout)); + + // Fields that depend on protocol version + if (version >= TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8) { + expect(req.parameters).to.exist; + expect(req.parameters?.length).to.equal(1); + } else { + expect(req.parameters).to.not.exist; + } + + if (version >= TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6) { + expect(req.canDecompressLZ4Result).to.be.true; + } else { + expect(req.canDecompressLZ4Result).to.not.exist; + } + + if (version >= TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5) { + expect(req.canReadArrowResult).to.be.true; + expect(req.useArrowNativeTypes).to.not.be.undefined; + } else if (version >= TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3) { + // V3 and V4 have canDownloadResult but not arrow-related fields + expect(req.canReadArrowResult).to.be.false; + expect(req.useArrowNativeTypes).to.not.exist; + expect(req.canDownloadResult).to.be.true; + } else { + // V1 and V2 don't have arrow or download features + expect(req.canReadArrowResult).to.be.false; + expect(req.useArrowNativeTypes).to.not.exist; + expect(req.canDownloadResult).to.not.exist; + } + }); + } + }); }); describe('getTypeInfo', () => { diff --git a/tests/unit/utils/protocolVersion.test.ts b/tests/unit/utils/protocolVersion.test.ts new file mode 100644 index 00000000..3469dfe8 --- /dev/null +++ b/tests/unit/utils/protocolVersion.test.ts @@ -0,0 +1,74 @@ +import { expect } from 'chai'; +import { TProtocolVersion } from '../../../thrift/TCLIService_types'; +import * as ProtocolVersion from '../../../lib/utils/protocolVersion'; + +describe('Protocol Version Utility - Parameterized Tests', () => { + // Define minimum protocol versions for each feature + const MIN_VERSION_CLOUD_FETCH = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3; + const MIN_VERSION_MULTIPLE_CATALOGS = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4; + const MIN_VERSION_ARROW_METADATA = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5; + const MIN_VERSION_ARROW_COMPRESSION = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6; + const MIN_VERSION_ASYNC_METADATA = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6; + const MIN_VERSION_RESULT_PERSISTENCE = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7; + const MIN_VERSION_PARAMETERIZED = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8; + + // Create an array of all protocol versions to test against + const protocolVersions = [ + TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V10, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V1, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V2, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, + ]; + + // Test each protocol version against each feature function + protocolVersions.forEach((version) => { + describe(`with protocol version ${version}`, () => { + it('supportsCloudFetch', () => { + const expected = version >= MIN_VERSION_CLOUD_FETCH; + const actual = ProtocolVersion.supportsCloudFetch(version); + expect(actual).to.equal(expected); + }); + + it('supportsMultipleCatalogs', () => { + const expected = version >= MIN_VERSION_MULTIPLE_CATALOGS; + const actual = ProtocolVersion.supportsMultipleCatalogs(version); + expect(actual).to.equal(expected); + }); + + it('supportsArrowMetadata', () => { + const expected = version >= MIN_VERSION_ARROW_METADATA; + const actual = ProtocolVersion.supportsArrowMetadata(version); + expect(actual).to.equal(expected); + }); + + it('supportsArrowCompression', () => { + const expected = version >= MIN_VERSION_ARROW_COMPRESSION; + const actual = ProtocolVersion.supportsArrowCompression(version); + expect(actual).to.equal(expected); + }); + + it('supportsAsyncMetadataOperations', () => { + const expected = version >= MIN_VERSION_ASYNC_METADATA; + const actual = ProtocolVersion.supportsAsyncMetadataOperations(version); + expect(actual).to.equal(expected); + }); + + it('supportsResultPersistenceMode', () => { + const expected = version >= MIN_VERSION_RESULT_PERSISTENCE; + const actual = ProtocolVersion.supportsResultPersistenceMode(version); + expect(actual).to.equal(expected); + }); + + it('supportsParameterizedQueries', () => { + const expected = version >= MIN_VERSION_PARAMETERIZED; + const actual = ProtocolVersion.supportsParameterizedQueries(version); + expect(actual).to.equal(expected); + }); + }); + }); +});