Skip to content

Commit aca308f

Browse files
committed
Register arrow compute calls
`arrow-odbc-spi-impl-test` pass locally. Remove `RUN_ALL_TESTS` that wasn't needed; it is no longer needed after fix of GH-47434.
1 parent 16ceade commit aca308f

File tree

5 files changed

+38
-15
lines changed

5 files changed

+38
-15
lines changed

cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,12 @@ if(WIN32)
9999
system_dsn.cc)
100100
endif()
101101

102-
target_link_libraries(arrow_odbc_spi_impl PUBLIC odbcabstraction arrow_flight_sql_shared)
102+
target_link_libraries(arrow_odbc_spi_impl PUBLIC odbcabstraction arrow_flight_sql_shared
103+
arrow_compute_shared Boost::locale)
103104

104-
if(MSVC)
105-
target_link_libraries(arrow_odbc_spi_impl PUBLIC Boost::locale)
105+
# Link libraries on MINGW64 only
106+
if((MINGW AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU") OR APPLE)
107+
target_link_libraries(arrow_odbc_spi_impl PUBLIC ${ODBCINST})
106108
endif()
107109

108110
set_target_properties(arrow_odbc_spi_impl
@@ -121,7 +123,7 @@ set_target_properties(arrow_odbc_spi_impl_cli
121123
target_link_libraries(arrow_odbc_spi_impl_cli arrow_odbc_spi_impl)
122124

123125
# Unit tests
124-
add_arrow_test(arrow_odbc_spi_impl_test
126+
add_arrow_test(odbc_spi_impl_test
125127
SOURCES
126128
accessors/boolean_array_accessor_test.cc
127129
accessors/binary_array_accessor_test.cc

cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,3 @@ TEST(PopulateCallOptionsTest, GenericOptionWithSpaces) {
204204

205205
} // namespace flight_sql
206206
} // namespace driver
207-
208-
int main(int argc, char** argv) {
209-
::testing::InitGoogleTest(&argc, argv);
210-
return RUN_ALL_TESTS();
211-
}

cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h"
19+
#include "arrow/compute/api.h"
1920
#include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h"
21+
#include "arrow/flight/sql/odbc/flight_sql/utils.h"
2022
#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h"
2123
#include "arrow/util/io_util.h"
2224
#include "arrow/util/logging.h"
@@ -33,6 +35,8 @@ using odbcabstraction::OdbcVersion;
3335

3436
FlightSqlDriver::FlightSqlDriver()
3537
: diagnostics_("Apache Arrow", "Flight SQL", OdbcVersion::V_3), version_("0.9.0.0") {
38+
RegisterComputeKernels();
39+
// Register log after compute kernels check to avoid segfaults
3640
RegisterLog();
3741
}
3842

@@ -52,6 +56,17 @@ odbcabstraction::Diagnostics& FlightSqlDriver::GetDiagnostics() { return diagnos
5256

5357
void FlightSqlDriver::SetVersion(std::string version) { version_ = std::move(version); }
5458

59+
void FlightSqlDriver::RegisterComputeKernels() {
60+
auto registry = arrow::compute::GetFunctionRegistry();
61+
62+
// strptime is one of the required compute functions
63+
auto strptime_func = registry->GetFunction("strptime");
64+
if (!strptime_func.ok()) {
65+
// Register Kernel functions to library
66+
ThrowIfNotOK(arrow::compute::Initialize());
67+
}
68+
}
69+
5570
void FlightSqlDriver::RegisterLog() {
5671
std::string log_level_str = arrow::internal::GetEnvVar(kODBCLogLevel)
5772
.Map(arrow::internal::AsciiToLower)

cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class FlightSqlDriver : public odbcabstraction::Driver {
3939

4040
void SetVersion(std::string version) override;
4141

42+
/// Register Arrow Compute kernels once.
43+
void RegisterComputeKernels();
44+
4245
void RegisterLog() override;
4346
};
4447

cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h"
2121

22+
#include "arrow/compute/initialize.h"
2223
#include "arrow/testing/builder.h"
2324
#include "arrow/testing/gtest_util.h"
2425
#include "arrow/testing/util.h"
@@ -27,6 +28,13 @@
2728
namespace driver {
2829
namespace flight_sql {
2930

31+
class UtilTestsWithCompute : public ::testing::Test {
32+
public:
33+
// This must be done before using the compute kernels in order to
34+
// register them to the FunctionRegistry.
35+
void SetUp() override { ASSERT_OK(arrow::compute::Initialize()); }
36+
};
37+
3038
void AssertConvertedArray(const std::shared_ptr<arrow::Array>& expected_array,
3139
const std::shared_ptr<arrow::Array>& converted_array,
3240
uint64_t size, arrow::Type::type arrow_type) {
@@ -80,7 +88,7 @@ void TestTime64ArrayConversion(const std::vector<int64_t>& input,
8088
AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type);
8189
}
8290

83-
TEST(Utils, Time32ToTimeStampArray) {
91+
TEST_F(UtilTestsWithCompute, Time32ToTimeStampArray) {
8492
std::vector<int32_t> input_data = {14896, 17820};
8593

8694
const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch();
@@ -100,7 +108,7 @@ TEST(Utils, Time32ToTimeStampArray) {
100108
arrow::Type::TIMESTAMP);
101109
}
102110

103-
TEST(Utils, Time64ToTimeStampArray) {
111+
TEST_F(UtilTestsWithCompute, Time64ToTimeStampArray) {
104112
std::vector<int64_t> input_data = {1579489200000, 1646881200000};
105113

106114
const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch();
@@ -120,7 +128,7 @@ TEST(Utils, Time64ToTimeStampArray) {
120128
arrow::Type::TIMESTAMP);
121129
}
122130

123-
TEST(Utils, StringToDateArray) {
131+
TEST_F(UtilTestsWithCompute, StringToDateArray) {
124132
std::shared_ptr<arrow::Array> expected;
125133
arrow::ArrayFromVector<arrow::Date64Type, int64_t>({1579489200000, 1646881200000},
126134
&expected);
@@ -129,7 +137,7 @@ TEST(Utils, StringToDateArray) {
129137
odbcabstraction::CDataType_DATE, arrow::Type::DATE64);
130138
}
131139

132-
TEST(Utils, StringToTimeArray) {
140+
TEST_F(UtilTestsWithCompute, StringToTimeArray) {
133141
std::shared_ptr<arrow::Array> expected;
134142
arrow::ArrayFromVector<arrow::Time64Type, int64_t>(
135143
time64(arrow::TimeUnit::MICRO), {36000000000, 43200000000}, &expected);
@@ -138,15 +146,15 @@ TEST(Utils, StringToTimeArray) {
138146
arrow::Type::TIME64);
139147
}
140148

141-
TEST(Utils, ConvertSqlPatternToRegexString) {
149+
TEST_F(UtilTestsWithCompute, ConvertSqlPatternToRegexString) {
142150
ASSERT_EQ(std::string("XY"), ConvertSqlPatternToRegexString("XY"));
143151
ASSERT_EQ(std::string("X.Y"), ConvertSqlPatternToRegexString("X_Y"));
144152
ASSERT_EQ(std::string("X.*Y"), ConvertSqlPatternToRegexString("X%Y"));
145153
ASSERT_EQ(std::string("X%Y"), ConvertSqlPatternToRegexString("X\\%Y"));
146154
ASSERT_EQ(std::string("X_Y"), ConvertSqlPatternToRegexString("X\\_Y"));
147155
}
148156

149-
TEST(Utils, ConvertToDBMSVer) {
157+
TEST_F(UtilTestsWithCompute, ConvertToDBMSVer) {
150158
ASSERT_EQ(std::string("01.02.0003"), ConvertToDBMSVer("1.2.3"));
151159
ASSERT_EQ(std::string("01.02.0003.0"), ConvertToDBMSVer("1.2.3.0"));
152160
ASSERT_EQ(std::string("01.02.0000"), ConvertToDBMSVer("1.2"));

0 commit comments

Comments
 (0)