Skip to content

Commit e532053

Browse files
zeroshadeianmcook
andauthored
feat(c/driver_manager,rust/driver_manager): handle virtual environments in driver manager (#3320)
Updates the driver manager logic to handle virtual environments used by venv and conda. venv will set a `VIRTUAL_ENV` environment variable while conda will set `CONDA_PREFIX`. If these environment variables are set, then they are added to the search list of `get_search_paths` / `GetSearchPaths` when the `LOAD_FLAG_SEARCH_ENV` flag is set. Also updates the docs to mention this. --------- Co-authored-by: Ian Cook <ianmcook@gmail.com>
1 parent e799e65 commit e532053

File tree

19 files changed

+510
-146
lines changed

19 files changed

+510
-146
lines changed

c/driver_manager/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ foreach(LIB_TARGET ${ADBC_LIBRARIES})
6969
${REPOSITORY_ROOT}/c/vendor
7070
${REPOSITORY_ROOT}/c/driver)
7171
target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING)
72+
if("$ENV{CONDA_BUILD}" STREQUAL "1")
73+
target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_CONDA_BUILD=1)
74+
else()
75+
target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_CONDA_BUILD=0)
76+
endif()
7277
endforeach()
7378

7479
if(ADBC_BUILD_TESTS)

c/driver_manager/adbc_driver_manager.cc

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ std::wstring Utf8Decode(const std::string& str) {
159159

160160
#else
161161
using char_type = char;
162-
#endif
162+
#endif // _WIN32
163163

164164
// Driver state
165165
struct DriverInfo {
@@ -239,7 +239,7 @@ AdbcStatusCode LoadDriverFromRegistry(HKEY root, const std::wstring& driver_name
239239
}
240240
return ADBC_STATUS_OK;
241241
}
242-
#endif
242+
#endif // _WIN32
243243

244244
AdbcStatusCode LoadDriverManifest(const std::filesystem::path& driver_manifest,
245245
DriverInfo& info, struct AdbcError* error) {
@@ -273,14 +273,39 @@ AdbcStatusCode LoadDriverManifest(const std::filesystem::path& driver_manifest,
273273
return ADBC_STATUS_OK;
274274
}
275275

276+
std::vector<std::filesystem::path> GetEnvPaths(const char_type* env_var) {
277+
#ifdef _WIN32
278+
size_t required_size;
279+
280+
_wgetenv_s(&required_size, NULL, 0, env_var);
281+
if (required_size == 0) {
282+
return {};
283+
}
284+
285+
std::wstring path_var;
286+
path_var.resize(required_size);
287+
_wgetenv_s(&required_size, path_var.data(), required_size, env_var);
288+
auto path = Utf8Encode(path_var);
289+
#else
290+
const char* path_var = std::getenv(env_var);
291+
if (!path_var) {
292+
return {};
293+
}
294+
std::string path(path_var);
295+
#endif // _WIN32
296+
return InternalAdbcParsePath(path);
297+
}
298+
276299
std::vector<std::filesystem::path> GetSearchPaths(const AdbcLoadFlags levels) {
277300
std::vector<std::filesystem::path> paths;
278301
if (levels & ADBC_LOAD_FLAG_SEARCH_ENV) {
302+
#ifdef _WIN32
303+
static const wchar_t* env_var = L"ADBC_CONFIG_PATH";
304+
#else
305+
static const char* env_var = "ADBC_CONFIG_PATH";
306+
#endif // _WIN32
279307
// Check the ADBC_CONFIG_PATH environment variable
280-
const char* env_path = std::getenv("ADBC_CONFIG_PATH");
281-
if (env_path) {
282-
paths = InternalAdbcParsePath(env_path);
283-
}
308+
paths = GetEnvPaths(env_var);
284309
}
285310

286311
if (levels & ADBC_LOAD_FLAG_SEARCH_USER) {
@@ -305,7 +330,7 @@ std::vector<std::filesystem::path> GetSearchPaths(const AdbcLoadFlags levels) {
305330
if (std::filesystem::exists(system_config_dir)) {
306331
paths.push_back(system_config_dir);
307332
}
308-
#endif
333+
#endif // defined(__APPLE__)
309334
}
310335

311336
return paths;
@@ -319,7 +344,7 @@ bool HasExtension(const std::filesystem::path& path, const std::string& ext) {
319344
_wcsnicmp(path_ext.data(), wext.data(), wext.size()) == 0;
320345
#else
321346
return path.extension() == ext;
322-
#endif
347+
#endif // _WIN32
323348
}
324349

325350
/// A driver DLL.
@@ -344,9 +369,10 @@ struct ManagedLibrary {
344369
// release() from the DLL - how to handle this?
345370
}
346371

347-
AdbcStatusCode GetDriverInfo(const std::string_view driver_name,
348-
const AdbcLoadFlags load_options, DriverInfo& info,
349-
struct AdbcError* error) {
372+
AdbcStatusCode GetDriverInfo(
373+
const std::string_view driver_name, const AdbcLoadFlags load_options,
374+
const std::vector<std::filesystem::path>& additional_search_paths, DriverInfo& info,
375+
struct AdbcError* error) {
350376
if (driver_name.empty()) {
351377
SetError(error, "Driver name is empty");
352378
return ADBC_STATUS_INVALID_ARGUMENT;
@@ -405,7 +431,7 @@ struct ManagedLibrary {
405431
static const std::string kPlatformLibrarySuffix = ".dylib";
406432
#else
407433
static const std::string kPlatformLibrarySuffix = ".so";
408-
#endif
434+
#endif // defined(_WIN32)
409435
if (HasExtension(driver_path, kPlatformLibrarySuffix)) {
410436
info.lib_path = driver_path;
411437
return Load(driver_path.c_str(), error);
@@ -418,7 +444,7 @@ struct ManagedLibrary {
418444

419445
// not an absolute path, no extension. Let's search the configured paths
420446
// based on the options
421-
return FindDriver(driver_path, load_options, info, error);
447+
return FindDriver(driver_path, load_options, additional_search_paths, info, error);
422448
}
423449

424450
AdbcStatusCode SearchPaths(const std::filesystem::path& driver_path,
@@ -446,27 +472,52 @@ struct ManagedLibrary {
446472
return ADBC_STATUS_NOT_FOUND;
447473
}
448474

449-
AdbcStatusCode FindDriver(const std::filesystem::path& driver_path,
450-
const AdbcLoadFlags load_options, DriverInfo& info,
451-
struct AdbcError* error) {
475+
AdbcStatusCode FindDriver(
476+
const std::filesystem::path& driver_path, const AdbcLoadFlags load_options,
477+
const std::vector<std::filesystem::path>& additional_search_paths, DriverInfo& info,
478+
struct AdbcError* error) {
452479
if (driver_path.empty()) {
453480
SetError(error, "Driver path is empty");
454481
return ADBC_STATUS_INVALID_ARGUMENT;
455482
}
456483

484+
{
485+
// First search the paths in the env var `ADBC_CONFIG_PATH`.
486+
// Then search the runtime application-defined additional search paths.
487+
auto search_paths = GetSearchPaths(load_options & ADBC_LOAD_FLAG_SEARCH_ENV);
488+
search_paths.insert(search_paths.end(), additional_search_paths.begin(),
489+
additional_search_paths.end());
490+
491+
#if ADBC_CONDA_BUILD
492+
// Then, if this is a conda build, search in the conda environment if
493+
// it is activated.
494+
if (load_options & ADBC_LOAD_FLAG_SEARCH_ENV) {
457495
#ifdef _WIN32
458-
// windows is slightly more complex since we also need to check registry keys
459-
// so we can't just grab the search paths only.
460-
if (load_options & ADBC_LOAD_FLAG_SEARCH_ENV) {
461-
auto search_paths = GetSearchPaths(ADBC_LOAD_FLAG_SEARCH_ENV);
496+
const wchar_t* conda_name = L"CONDA_PREFIX";
497+
#else
498+
const char* conda_name = "CONDA_PREFIX";
499+
#endif // _WIN32
500+
auto venv = GetEnvPaths(conda_name);
501+
if (!venv.empty()) {
502+
for (const auto& venv_path : venv) {
503+
search_paths.push_back(venv_path / "etc" / "adbc");
504+
}
505+
}
506+
}
507+
#endif // ADBC_CONDA_BUILD
508+
462509
auto status = SearchPaths(driver_path, search_paths, info, error);
463510
if (status == ADBC_STATUS_OK) {
464511
return status;
465512
}
466513
}
467514

515+
// We searched environment paths and additional search paths (if they
516+
// exist), so now search the rest.
517+
#ifdef _WIN32
518+
// On Windows, check registry keys, not just search paths.
468519
if (load_options & ADBC_LOAD_FLAG_SEARCH_USER) {
469-
// Check the user registry for the driver
520+
// Check the user registry for the driver.
470521
auto status =
471522
LoadDriverFromRegistry(HKEY_CURRENT_USER, driver_path.native(), info, error);
472523
if (status == ADBC_STATUS_OK) {
@@ -481,7 +532,7 @@ struct ManagedLibrary {
481532
}
482533

483534
if (load_options & ADBC_LOAD_FLAG_SEARCH_SYSTEM) {
484-
// Check the system registry for the driver
535+
// Check the system registry for the driver.
485536
auto status =
486537
LoadDriverFromRegistry(HKEY_LOCAL_MACHINE, driver_path.native(), info, error);
487538
if (status == ADBC_STATUS_OK) {
@@ -498,17 +549,17 @@ struct ManagedLibrary {
498549
info.lib_path = driver_path;
499550
return Load(driver_path.c_str(), error);
500551
#else
501-
// Otherwise, search the configured paths
502-
auto search_paths = GetSearchPaths(load_options);
552+
// Otherwise, search the configured paths.
553+
auto search_paths = GetSearchPaths(load_options & ~ADBC_LOAD_FLAG_SEARCH_ENV);
503554
auto status = SearchPaths(driver_path, search_paths, info, error);
504555
if (status == ADBC_STATUS_NOT_FOUND) {
505556
// If we reach here, we didn't find the driver in any of the paths
506-
// so let's just attempt to load it as default behavior
557+
// so let's just attempt to load it as default behavior.
507558
info.lib_path = driver_path;
508559
return Load(driver_path.c_str(), error);
509560
}
510561
return status;
511-
#endif
562+
#endif // _WIN32
512563
}
513564

514565
AdbcStatusCode Load(const char_type* library, struct AdbcError* error) {
@@ -991,6 +1042,7 @@ struct TempDatabase {
9911042
std::string entrypoint;
9921043
AdbcDriverInitFunc init_func = nullptr;
9931044
AdbcLoadFlags load_flags = ADBC_LOAD_FLAG_ALLOW_RELATIVE_PATHS;
1045+
std::string additional_search_path_list;
9941046
};
9951047

9961048
/// Temporary state while the database is being configured.
@@ -1043,7 +1095,7 @@ std::filesystem::path InternalAdbcUserConfigDir() {
10431095
if (!config_dir.empty()) {
10441096
config_dir /= "adbc";
10451097
}
1046-
#endif
1098+
#endif // defined(_WIN32)
10471099

10481100
return config_dir;
10491101
}
@@ -1362,6 +1414,22 @@ AdbcStatusCode AdbcDriverManagerDatabaseSetLoadFlags(struct AdbcDatabase* databa
13621414
return ADBC_STATUS_OK;
13631415
}
13641416

1417+
AdbcStatusCode AdbcDriverManagerDatabaseSetAdditionalSearchPathList(
1418+
struct AdbcDatabase* database, const char* path_list, struct AdbcError* error) {
1419+
if (database->private_driver) {
1420+
SetError(error, "Cannot SetAdditionalSearchPathList after AdbcDatabaseInit");
1421+
return ADBC_STATUS_INVALID_STATE;
1422+
}
1423+
1424+
TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
1425+
if (path_list) {
1426+
args->additional_search_path_list.assign(path_list);
1427+
} else {
1428+
args->additional_search_path_list.clear();
1429+
}
1430+
return ADBC_STATUS_OK;
1431+
}
1432+
13651433
AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database,
13661434
AdbcDriverInitFunc init_func,
13671435
struct AdbcError* error) {
@@ -1399,10 +1467,12 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
13991467
} else if (!args->entrypoint.empty()) {
14001468
status = AdbcFindLoadDriver(args->driver.c_str(), args->entrypoint.c_str(),
14011469
ADBC_VERSION_1_1_0, args->load_flags,
1470+
args->additional_search_path_list.data(),
14021471
database->private_driver, error);
14031472
} else {
1404-
status = AdbcFindLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0,
1405-
args->load_flags, database->private_driver, error);
1473+
status = AdbcFindLoadDriver(
1474+
args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, args->load_flags,
1475+
args->additional_search_path_list.data(), database->private_driver, error);
14061476
}
14071477

14081478
if (status != ADBC_STATUS_OK) {
@@ -2134,6 +2204,7 @@ const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
21342204

21352205
AdbcStatusCode AdbcFindLoadDriver(const char* driver_name, const char* entrypoint,
21362206
const int version, const AdbcLoadFlags load_options,
2207+
const char* additional_search_path_list,
21372208
void* raw_driver, struct AdbcError* error) {
21382209
AdbcDriverInitFunc init_func = nullptr;
21392210
std::string error_message;
@@ -2163,7 +2234,13 @@ AdbcStatusCode AdbcFindLoadDriver(const char* driver_name, const char* entrypoin
21632234
info.entrypoint = entrypoint;
21642235
}
21652236

2166-
AdbcStatusCode status = library.GetDriverInfo(driver_name, load_options, info, error);
2237+
std::vector<std::filesystem::path> additional_paths;
2238+
if (additional_search_path_list) {
2239+
additional_paths = InternalAdbcParsePath(additional_search_path_list);
2240+
}
2241+
2242+
AdbcStatusCode status =
2243+
library.GetDriverInfo(driver_name, load_options, additional_paths, info, error);
21672244
if (status != ADBC_STATUS_OK) {
21682245
driver->release = nullptr;
21692246
return status;
@@ -2205,7 +2282,8 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
22052282
// but don't enable searching for manifests by default. It will need to be explicitly
22062283
// enabled by calling AdbcFindLoadDriver directly.
22072284
return AdbcFindLoadDriver(driver_name, entrypoint, version,
2208-
ADBC_LOAD_FLAG_ALLOW_RELATIVE_PATHS, raw_driver, error);
2285+
ADBC_LOAD_FLAG_ALLOW_RELATIVE_PATHS, nullptr, raw_driver,
2286+
error);
22092287
}
22102288

22112289
AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version,

0 commit comments

Comments
 (0)