Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions c/driver_manager/adbc_driver_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,35 @@ void AddSearchPathsToError(const SearchPaths& search_paths, std::string& error_m
}
}

// Generate a note for the error message if the library name has potentially
// non-printable (or really non-ASCII-printable-range) characters. Oblivious
// to Unicode and locales.
std::string CheckNonPrintableLibraryName(const std::string& name) {
// We could use std::isprint, but that requires locales; prefer a
// simpler check for out-of-ASCII-range.
bool has_non_printable = std::any_of(name.begin(), name.end(), [](char c) {
int v = static_cast<int>(c);
return v < 32 || v > 127;
});
if (!has_non_printable) return "";

std::string error_message = "Note: driver name may have non-printable characters: `";
// TODO(lidavidm): we can simplify with C++20 <format>
for (char c : name) {
int v = static_cast<int>(c);
if (v < 32 || v > 127) {
error_message += "\\x";
char buf[3];
std::snprintf(buf, sizeof(buf), "%02x", v & 0xFF);
error_message += buf;
} else {
error_message += c;
}
}
error_message += "`";
return error_message;
}

// Platform-specific helpers

#if defined(_WIN32)
Expand Down Expand Up @@ -230,6 +259,7 @@ struct OwnedError {

#ifdef _WIN32
using char_type = wchar_t;
using string_type = std::wstring;

std::string Utf8Encode(const std::wstring& wstr) {
if (wstr.empty()) return std::string();
Expand All @@ -253,6 +283,7 @@ std::wstring Utf8Decode(const std::string& str) {

#else
using char_type = char;
using string_type = std::string;
#endif // _WIN32

/// \brief The location and entrypoint of a resolved driver.
Expand Down Expand Up @@ -586,15 +617,15 @@ struct ManagedLibrary {

auto status = LoadDriverManifest(driver_path, info, error);
if (status == ADBC_STATUS_OK) {
return Load(info.lib_path.c_str(), {}, error);
return Load(info.lib_path.native(), {}, error);
}
return status;
}

// if the extension is not .toml, then just try to load the provided
// path as if it was an absolute path to a driver library
info.lib_path = driver_path;
return Load(driver_path.c_str(), {}, error);
return Load(driver_path.native(), {}, error);
}

if (driver_path.is_absolute()) {
Expand All @@ -604,14 +635,14 @@ struct ManagedLibrary {
if (std::filesystem::exists(driver_path)) {
auto status = LoadDriverManifest(driver_path, info, error);
if (status == ADBC_STATUS_OK) {
return Load(info.lib_path.c_str(), {}, error);
return Load(info.lib_path.native(), {}, error);
}
}

driver_path.replace_extension("");
// otherwise just try to load the provided path as if it was an absolute path
info.lib_path = driver_path;
return Load(driver_path.c_str(), {}, error);
return Load(driver_path.native(), {}, error);
}

if (driver_path.has_extension()) {
Expand All @@ -629,7 +660,7 @@ struct ManagedLibrary {
#endif // defined(_WIN32)
if (HasExtension(driver_path, kPlatformLibrarySuffix)) {
info.lib_path = driver_path;
return Load(driver_path.c_str(), {}, error);
return Load(driver_path.native(), {}, error);
}

SetError(error, "Driver name has unrecognized extension: " +
Expand Down Expand Up @@ -674,7 +705,7 @@ struct ManagedLibrary {
auto status = LoadDriverManifest(full_path, info, &intermediate_error.error);
if (status == ADBC_STATUS_OK) {
// Don't pass attempted_paths here; we'll generate the error at a higher level
status = Load(info.lib_path.c_str(), {}, &intermediate_error.error);
status = Load(info.lib_path.native(), {}, &intermediate_error.error);
if (status == ADBC_STATUS_OK) {
return status;
}
Expand Down Expand Up @@ -718,7 +749,7 @@ struct ManagedLibrary {
// remove the .toml extension; Load will add the DLL/SO/DYLIB suffix
full_path.replace_extension("");
// Don't pass error here - it'll be suppressed anyways
auto status = Load(full_path.c_str(), {}, nullptr);
auto status = Load(full_path.native(), {}, nullptr);
if (status == ADBC_STATUS_OK) {
info.lib_path = full_path;
return status;
Expand Down Expand Up @@ -797,7 +828,7 @@ struct ManagedLibrary {
auto status =
LoadDriverFromRegistry(HKEY_CURRENT_USER, driver_path.native(), info, error);
if (status == ADBC_STATUS_OK) {
return Load(info.lib_path.c_str(), {}, error);
return Load(info.lib_path.native(), {}, error);
}
if (error && error->message) {
std::string message = "HKEY_CURRENT_USER\\"s;
Expand All @@ -824,7 +855,7 @@ struct ManagedLibrary {
auto status =
LoadDriverFromRegistry(HKEY_LOCAL_MACHINE, driver_path.native(), info, error);
if (status == ADBC_STATUS_OK) {
return Load(info.lib_path.c_str(), {}, error);
return Load(info.lib_path.native(), {}, error);
}
if (error && error->message) {
std::string message = "HKEY_LOCAL_MACHINE\\"s;
Expand All @@ -848,7 +879,7 @@ struct ManagedLibrary {
}

info.lib_path = driver_path;
return Load(driver_path.c_str(), search_paths, error);
return Load(driver_path.native(), search_paths, error);
#else
// Otherwise, search the configured paths.
SearchPaths more_search_paths =
Expand Down Expand Up @@ -880,22 +911,23 @@ struct ManagedLibrary {
search_paths.insert(search_paths.end(), more_search_paths.begin(),
more_search_paths.end());
info.lib_path = driver_path;
return Load(driver_path.c_str(), search_paths, error);
return Load(driver_path.native(), search_paths, error);
}
return status;
#endif // _WIN32
}

/// \return ADBC_STATUS_NOT_FOUND if the driver shared library could not be
/// found, ADBC_STATUS_OK otherwise
AdbcStatusCode Load(const char_type* library, const SearchPaths& attempted_paths,
AdbcStatusCode Load(const string_type& library, const SearchPaths& attempted_paths,
struct AdbcError* error) {
std::string error_message;
#if defined(_WIN32)
HMODULE handle = LoadLibraryExW(library, NULL, 0);
HMODULE handle = LoadLibraryExW(library.c_str(), NULL, 0);
if (!handle) {
error_message = "Could not load `";
error_message += Utf8Encode(library);
error_message += ": LoadLibraryExW() failed: ";
error_message += "`: LoadLibraryExW() failed: ";
GetWinError(&error_message);

std::wstring full_driver_name = library;
Expand All @@ -909,6 +941,12 @@ struct ManagedLibrary {
}
}
if (!handle) {
std::string name = Utf8Encode(library);
std::string message = CheckNonPrintableLibraryName(name);
if (!message.empty()) {
error_message += "\n";
error_message += message;
}
AddSearchPathsToError(attempted_paths, error_message);
SetError(error, error_message);
return ADBC_STATUS_NOT_FOUND;
Expand All @@ -923,9 +961,11 @@ struct ManagedLibrary {
static const std::string kPlatformLibrarySuffix = ".so";
#endif // defined(__APPLE__)

void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL);
void* handle = dlopen(library.c_str(), RTLD_NOW | RTLD_LOCAL);
if (!handle) {
error_message = "dlopen() failed: ";
error_message = "Could not load `";
error_message += library;
error_message += "`: dlopen() failed: ";
error_message += dlerror();

// If applicable, append the shared library prefix/extension and
Expand Down Expand Up @@ -955,6 +995,11 @@ struct ManagedLibrary {
if (handle) {
this->handle = handle;
} else {
std::string message = CheckNonPrintableLibraryName(library);
if (!message.empty()) {
error_message += "\n";
error_message += message;
}
AddSearchPathsToError(attempted_paths, error_message);
SetError(error, error_message);
return ADBC_STATUS_NOT_FOUND;
Expand Down
26 changes: 24 additions & 2 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1277,11 +1277,33 @@ shared = "adbc_driver_sqlite")";
ASSERT_THAT(AdbcDriverManagerDatabaseSetAdditionalSearchPathList(
&database.value, search_path.data(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseInit(&database.value, &error),
IsStatus(ADBC_STATUS_OK, &error));
ASSERT_THAT(AdbcDatabaseInit(&database.value, &error), IsOkStatus(&error));
}

ASSERT_TRUE(std::filesystem::remove(filepath));
}

TEST_F(DriverManifest, ControlCodes) {
const std::string uri = "\026sqlite:file::memory:";
for (const auto& driver_option : {"driver", "uri"}) {
SCOPED_TRACE(driver_option);
SCOPED_TRACE(uri);
adbc_validation::Handle<struct AdbcDatabase> database;
ASSERT_THAT(AdbcDatabaseNew(&database.value, &error), IsOkStatus(&error));
ASSERT_THAT(
AdbcDatabaseSetOption(&database.value, driver_option, uri.c_str(), &error),
IsOkStatus(&error));
std::string search_path = temp_dir.string();
ASSERT_THAT(AdbcDriverManagerDatabaseSetAdditionalSearchPathList(
&database.value, search_path.data(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseInit(&database.value, &error),
IsStatus(ADBC_STATUS_NOT_FOUND, &error));
ASSERT_THAT(
error.message,
::testing::HasSubstr(
"Note: driver name may have non-printable characters: `\\x16sqlite`"));
}
}

} // namespace adbc
Loading
Loading