|
1 | | -/* |
| 1 | +/* |
2 | 2 | * Licensed to the Apache Software Foundation (ASF) under one or more |
3 | 3 | * contributor license agreements. See the NOTICE file distributed with |
4 | 4 | * this work for additional information regarding copyright ownership. |
|
15 | 15 | * limitations under the License. |
16 | 16 | */ |
17 | 17 |
|
18 | | -using System; |
19 | 18 | using System.Collections.Generic; |
20 | | -using System.Net.Http; |
21 | | -using System.Net.Http.Headers; |
| 19 | +using System.Threading; |
| 20 | +using System.Threading.Tasks; |
22 | 21 | using Apache.Arrow.Adbc.Drivers.Apache; |
23 | 22 | using Apache.Arrow.Adbc.Drivers.Apache.Spark; |
24 | | -using Thrift; |
25 | | -using Thrift.Transport; |
| 23 | +using Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch; |
| 24 | +using Apache.Arrow.Ipc; |
| 25 | +using Apache.Hive.Service.Rpc.Thrift; |
26 | 26 |
|
27 | 27 | namespace Apache.Arrow.Adbc.Drivers.Databricks |
28 | 28 | { |
29 | | - /// <summary> |
30 | | - /// Databricks-specific implementation of <see cref="AdbcConnection"/> |
31 | | - /// </summary> |
32 | | - internal class DatabricksConnection : SparkDatabricksConnection |
| 29 | + internal class DatabricksConnection : SparkHttpConnection |
33 | 30 | { |
34 | | - protected new const string ProductVersionDefault = "1.0.0"; |
35 | | - protected new const string DriverName = "ADBC Databricks Driver"; |
36 | | - private const string ArrowVersion = "1.0.0"; |
37 | | - private static readonly string s_userAgent = $"{DriverName.Replace(" ", "")}/{ProductVersionDefault}"; |
38 | | - |
39 | 31 | public DatabricksConnection(IReadOnlyDictionary<string, string> properties) : base(properties) |
40 | 32 | { |
41 | 33 | } |
42 | 34 |
|
43 | | - protected override TTransport CreateTransport() |
| 35 | + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) |
44 | 36 | { |
45 | | - // Assumption: parameters have already been validated. |
46 | | - Properties.TryGetValue(SparkParameters.HostName, out string? hostName); |
47 | | - Properties.TryGetValue(SparkParameters.Path, out string? path); |
48 | | - Properties.TryGetValue(SparkParameters.Port, out string? port); |
49 | | - Properties.TryGetValue(SparkParameters.AuthType, out string? authType); |
50 | | - if (!SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue)) |
| 37 | + // Get result format from metadata response if available |
| 38 | + TSparkRowSetType resultFormat = TSparkRowSetType.ARROW_BASED_SET; |
| 39 | + bool isLz4Compressed = false; |
| 40 | + |
| 41 | + if (metadataResp != null) |
| 42 | + { |
| 43 | + if (metadataResp.__isset.resultFormat) |
| 44 | + { |
| 45 | + resultFormat = metadataResp.ResultFormat; |
| 46 | + } |
| 47 | + |
| 48 | + if (metadataResp.__isset.lz4Compressed) |
| 49 | + { |
| 50 | + isLz4Compressed = metadataResp.Lz4Compressed; |
| 51 | + } |
| 52 | + } |
| 53 | + |
| 54 | + // Choose the appropriate reader based on the result format |
| 55 | + if (resultFormat == TSparkRowSetType.URL_BASED_SET) |
51 | 56 | { |
52 | | - throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); |
| 57 | + return new SparkCloudFetchReader(statement, schema, isLz4Compressed); |
53 | 58 | } |
54 | | - Properties.TryGetValue(SparkParameters.Token, out string? token); |
55 | | - Properties.TryGetValue(SparkParameters.AccessToken, out string? access_token); |
56 | | - Properties.TryGetValue(AdbcOptions.Username, out string? username); |
57 | | - Properties.TryGetValue(AdbcOptions.Password, out string? password); |
58 | | - Properties.TryGetValue(AdbcOptions.Uri, out string? uri); |
| 59 | + else |
| 60 | + { |
| 61 | + return new DatabricksReader(statement, schema); |
| 62 | + } |
| 63 | + } |
59 | 64 |
|
60 | | - Uri baseAddress = GetBaseAddress(uri, hostName, path, port, SparkParameters.HostName); |
61 | | - AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password, access_token); |
| 65 | + internal override SchemaParser SchemaParser => new DatabricksSchemaParser(); |
62 | 66 |
|
63 | | - HttpClientHandler httpClientHandler = NewHttpClientHandler(); |
64 | | - Lz4CompressionHandler lz4CompressionHandler = new Lz4CompressionHandler { InnerHandler = httpClientHandler }; |
65 | | - HttpClient httpClient = new(lz4CompressionHandler); |
66 | | - httpClient.BaseAddress = baseAddress; |
67 | | - httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; |
68 | | - httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); |
69 | | - httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); |
70 | | - httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); |
71 | | - httpClient.DefaultRequestHeaders.ExpectContinue = false; |
| 67 | + //internal override SparkServerType ServerType => SparkServerType.Databricks; |
72 | 68 |
|
73 | | - TConfiguration config = new(); |
74 | | - ThriftHttpTransport transport = new(httpClient, config) |
| 69 | + protected override TOpenSessionReq CreateSessionRequest() |
| 70 | + { |
| 71 | + var req = new TOpenSessionReq |
75 | 72 | { |
76 | | - // This value can only be set before the first call/request. So if a new value for query timeout |
77 | | - // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token |
78 | | - // to ensure cancelled correctly. |
79 | | - ConnectTimeout = int.MaxValue, |
| 73 | + Client_protocol = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, |
| 74 | + Client_protocol_i64 = (long)TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, |
| 75 | + CanUseMultipleCatalogs = true, |
80 | 76 | }; |
81 | | - return transport; |
| 77 | + return req; |
82 | 78 | } |
| 79 | + |
| 80 | + protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => |
| 81 | + Task.FromResult(response.DirectResults.ResultSetMetadata); |
| 82 | + protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => |
| 83 | + Task.FromResult(response.DirectResults.ResultSetMetadata); |
| 84 | + protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => |
| 85 | + Task.FromResult(response.DirectResults.ResultSetMetadata); |
| 86 | + protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => |
| 87 | + Task.FromResult(response.DirectResults.ResultSetMetadata); |
| 88 | + protected internal override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => |
| 89 | + Task.FromResult(response.DirectResults.ResultSetMetadata); |
| 90 | + |
| 91 | + protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => |
| 92 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
| 93 | + protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => |
| 94 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
| 95 | + protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => |
| 96 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
| 97 | + protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => |
| 98 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
| 99 | + protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => |
| 100 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
| 101 | + protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => |
| 102 | + Task.FromResult(response.DirectResults.ResultSet.Results); |
83 | 103 | } |
84 | 104 | } |
0 commit comments