diff --git a/source/pdo_sqlsrv/config.w32 b/source/pdo_sqlsrv/config.w32 index 0ec718f58..e3e9ca494 100644 --- a/source/pdo_sqlsrv/config.w32 +++ b/source/pdo_sqlsrv/config.w32 @@ -27,10 +27,13 @@ if( PHP_PDO_SQLSRV != "no" ) { if (CHECK_LIB("odbc32.lib", "pdo_sqlsrv") && CHECK_LIB("odbccp32.lib", "pdo_sqlsrv") && CHECK_LIB("version.lib", "pdo_sqlsrv") && CHECK_LIB("psapi.lib", "pdo_sqlsrv")&& CHECK_HEADER_ADD_INCLUDE( "core_sqlsrv.h", "CFLAGS_PDO_SQLSRV", configure_module_dirname + "\\shared")) { + CHECK_HEADER_ADD_INCLUDE("php_pdo_sqlsrv_int.h", "CFLAGS_PDO_SQLSRV", configure_module_dirname); + CHECK_HEADER_ADD_INCLUDE("php_pdo_sqlsrv.h", "CFLAGS_PDO_SQLSRV", configure_module_dirname); CHECK_HEADER_ADD_INCLUDE("sql.h", "CFLAGS_PDO_SQLSRV_ODBC"); CHECK_HEADER_ADD_INCLUDE("sqlext.h", "CFLAGS_PDO_SQLSRV_ODBC"); ADD_SOURCES( configure_module_dirname + "\\shared", shared_src_class, "pdo_sqlsrv" ); ADD_FLAG( "LDFLAGS_PDO_SQLSRV", "/NXCOMPAT /DYNAMICBASE /debug /guard:cf" ); + ADD_FLAG( "CFLAGS_PDO_SQLSRV", "/D PDO_SQLSRV" ); ADD_FLAG( "CFLAGS_PDO_SQLSRV", "/EHsc" ); ADD_FLAG( "CFLAGS_PDO_SQLSRV", "/GS" ); ADD_FLAG( "CFLAGS_PDO_SQLSRV", "/Zi" ); diff --git a/source/pdo_sqlsrv/pdo_init.cpp b/source/pdo_sqlsrv/pdo_init.cpp index 3d7d23953..f6da9ff95 100644 --- a/source/pdo_sqlsrv/pdo_init.cpp +++ b/source/pdo_sqlsrv/pdo_init.cpp @@ -249,7 +249,15 @@ PHP_RSHUTDOWN_FUNCTION(pdo_sqlsrv) { // SQLSRV_UNUSED( module_number ); // SQLSRV_UNUSED( type ); - +#if defined(_WIN32) && !defined(ZTS) + for (size_t current_token = 0; current_token < PDO_SQLSRV_G(access_tokens_size); current_token++) { + if (PDO_SQLSRV_G(access_tokens)[current_token]) { + memset(PDO_SQLSRV_G(access_tokens)[current_token]->data, 0, PDO_SQLSRV_G(access_tokens)[current_token]->dataSize); + sqlsrv_free(PDO_SQLSRV_G(access_tokens)[current_token]); + } + } + sqlsrv_free(PDO_SQLSRV_G(access_tokens)); +#endif PDO_LOG_NOTICE("pdo_sqlsrv: entering rshutdown"); return SUCCESS; diff --git a/source/pdo_sqlsrv/php_pdo_sqlsrv.h b/source/pdo_sqlsrv/php_pdo_sqlsrv.h index 923eb20da..dd00d695e 100644 --- a/source/pdo_sqlsrv/php_pdo_sqlsrv.h +++ b/source/pdo_sqlsrv/php_pdo_sqlsrv.h @@ -22,6 +22,9 @@ #include "php.h" +#ifdef _WIN32 +#include "msodbcsql.h" +#endif //********************************************************************************************************************************* // Global variables //********************************************************************************************************************************* @@ -35,6 +38,9 @@ short report_additional_errors; #ifndef _WIN32 zend_long set_locale_info; +#else +ACCESSTOKEN** access_tokens; +unsigned int access_tokens_size = 0; #endif ZEND_END_MODULE_GLOBALS(pdo_sqlsrv) diff --git a/source/shared/core_conn.cpp b/source/shared/core_conn.cpp index a5f87ba38..41ba63449 100644 --- a/source/shared/core_conn.cpp +++ b/source/shared/core_conn.cpp @@ -32,6 +32,18 @@ #ifndef _WIN32 #include #include +#else +#ifdef PDO_SQLSRV +extern "C" { +#include "php_pdo_sqlsrv.h" +} +#include "php_pdo_sqlsrv_int.h" +#elif SQLSRV +extern "C" { +#include "php_sqlsrv.h" +} +#include "php_sqlsrv_int.h" +#endif #endif // *** internal variables and constants *** @@ -220,12 +232,13 @@ sqlsrv_conn* core_sqlsrv_connect( _In_ sqlsrv_context& henv_cp, _In_ sqlsrv_cont } } +#ifdef ZTS // time to free the access token, if not null if (conn->azure_ad_access_token) { memset(conn->azure_ad_access_token->data, 0, conn->azure_ad_access_token->dataSize); // clear the memory conn->azure_ad_access_token.reset(); } - +#endif CHECK_SQL_ERROR( r, conn ) { throw core::CoreException(); } @@ -1153,7 +1166,17 @@ size_t core_str_zval_is_true(_Inout_ zval* value_z) return 0; // false } -void access_token_set_func::func( _In_ connection_option const* option, _In_ zval* value, _Inout_ sqlsrv_conn* conn, _Inout_ std::string& conn_str ) +#if defined(_WIN32) && !defined(ZTS) +ACCESSTOKEN** get_access_tokens() { +#ifdef PDO_SQLSRV + return PDO_SQLSRV_G(access_tokens); +#elif SQLSRV + return SQLSRV_G(access_tokens); +#endif +} +#endif + +void access_token_set_func::func(_In_ connection_option const* option, _In_ zval* value, _Inout_ sqlsrv_conn* conn, _Inout_ std::string& conn_str) { SQLSRV_ASSERT(Z_TYPE_P(value) == IS_STRING, "An access token must be a byte string."); @@ -1182,13 +1205,36 @@ void access_token_set_func::func( _In_ connection_option const* option, _In_ zva // similar to a UCS-2 string containing only ASCII characters // // See https://docs.microsoft.com/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token +#if defined(_WIN32) && !defined(ZTS) + size_t next_token_position = 0; + bool same_token_used = false; + #ifdef PDO_SQLSRV + unsigned int& access_tokens_size = PDO_SQLSRV_G(access_tokens_size); + #elif SQLSRV + unsigned int& access_tokens_size = SQLSRV_G(access_tokens_size); + #endif + + for (size_t current_token_index = 0; current_token_index < access_tokens_size; current_token_index++) { + std::string string_token; + for (size_t i = 0; i < get_access_tokens()[current_token_index]->dataSize; i += 2) { + string_token.push_back(get_access_tokens()[current_token_index]->data[i]); + } + if (string_token == std::string(value_str)) { + // Token already exists in access_toiens + memset(get_access_tokens()[current_token_index]->data, 0, get_access_tokens()[current_token_index]->dataSize); + sqlsrv_free(get_access_tokens()[current_token_index]); + next_token_position = current_token_index; + same_token_used = true; + break; + } + } +#endif size_t dataSize = 2 * value_len; - sqlsrv_malloc_auto_ptr accToken; accToken = reinterpret_cast(sqlsrv_malloc(sizeof(ACCESSTOKEN) + dataSize)); - ACCESSTOKEN *pAccToken = accToken.get(); + ACCESSTOKEN* pAccToken = accToken.get(); SQLSRV_ASSERT(pAccToken != NULL, "Something went wrong when trying to allocate memory for the access token."); pAccToken->dataSize = dataSize; @@ -1196,7 +1242,7 @@ void access_token_set_func::func( _In_ connection_option const* option, _In_ zva // Expand access token with padding bytes for (size_t i = 0, j = 0; i < dataSize; i += 2, j++) { pAccToken->data[i] = value_str[j]; - pAccToken->data[i+1] = 0; + pAccToken->data[i + 1] = 0; } core::SQLSetConnectAttr(conn, SQL_COPT_SS_ACCESS_TOKEN, reinterpret_cast(pAccToken), SQL_IS_POINTER); @@ -1204,4 +1250,16 @@ void access_token_set_func::func( _In_ connection_option const* option, _In_ zva // Save the pointer because SQLDriverConnect() will use it to make connection to the server conn->azure_ad_access_token = pAccToken; accToken.transferred(); +#if defined(_WIN32) && !defined(ZTS) + if (!same_token_used) { + next_token_position = access_tokens_size; + access_tokens_size++; + #ifdef PDO_SQLSRV + PDO_SQLSRV_G(access_tokens) = reinterpret_cast(sqlsrv_realloc(PDO_SQLSRV_G(access_tokens), access_tokens_size * sizeof(ACCESSTOKEN*))); + #elif SQLSRV + SQLSRV_G(access_tokens) = reinterpret_cast(sqlsrv_realloc(SQLSRV_G(access_tokens), access_tokens_size * sizeof(ACCESSTOKEN*))); + #endif + } + get_access_tokens()[next_token_position] = conn->azure_ad_access_token; +#endif } diff --git a/source/sqlsrv/config.w32 b/source/sqlsrv/config.w32 index 68a2d2207..cf6c4e8d9 100644 --- a/source/sqlsrv/config.w32 +++ b/source/sqlsrv/config.w32 @@ -32,7 +32,10 @@ if( PHP_SQLSRV != "no" ) { } CHECK_HEADER_ADD_INCLUDE("sql.h", "CFLAGS_SQLSRV_ODBC"); CHECK_HEADER_ADD_INCLUDE("sqlext.h", "CFLAGS_SQLSRV_ODBC"); + CHECK_HEADER_ADD_INCLUDE("php_sqlsrv_int.h", "CFLAGS_SQLSRV", configure_module_dirname); + CHECK_HEADER_ADD_INCLUDE("php_sqlsrv.h", "CFLAGS_SQLSRV", configure_module_dirname); ADD_FLAG( "LDFLAGS_SQLSRV", "/NXCOMPAT /DYNAMICBASE /debug /guard:cf" ); + ADD_FLAG( "CFLAGS_SQLSRV", "/D SQLSRV" ); ADD_FLAG( "CFLAGS_SQLSRV", "/D ZEND_WIN32_FORCE_INLINE" ); ADD_FLAG( "CFLAGS_SQLSRV", "/EHsc" ); ADD_FLAG( "CFLAGS_SQLSRV", "/GS" ); diff --git a/source/sqlsrv/init.cpp b/source/sqlsrv/init.cpp index a6f6e711d..5386cbfa2 100644 --- a/source/sqlsrv/init.cpp +++ b/source/sqlsrv/init.cpp @@ -696,6 +696,15 @@ PHP_RSHUTDOWN_FUNCTION(sqlsrv) // SQLSRV_UNUSED( type ); LOG_FUNCTION( "PHP_RSHUTDOWN for php_sqlsrv" ); +#if defined(_WIN32) && !defined(ZTS) + for (size_t current_token = 0; current_token < SQLSRV_G(access_tokens_size); current_token++) { + if (SQLSRV_G(access_tokens)[current_token]) { + memset(SQLSRV_G(access_tokens)[current_token]->data, 0, SQLSRV_G(access_tokens)[current_token]->dataSize); + sqlsrv_free(SQLSRV_G(access_tokens)[current_token]); + } + } + sqlsrv_free(SQLSRV_G(access_tokens)); +#endif reset_errors(); // destruction diff --git a/source/sqlsrv/php_sqlsrv.h b/source/sqlsrv/php_sqlsrv.h index b30b41c6d..336635478 100644 --- a/source/sqlsrv/php_sqlsrv.h +++ b/source/sqlsrv/php_sqlsrv.h @@ -24,6 +24,9 @@ #include "php.h" +#ifdef _WIN32 +#include "msodbcsql.h" +#endif //********************************************************************************************************************************* // Global variables //********************************************************************************************************************************* @@ -44,6 +47,9 @@ zend_long buffered_query_limit; #ifndef _WIN32 zend_long set_locale_info; +#else +ACCESSTOKEN** access_tokens; +unsigned int access_tokens_size = 0; #endif ZEND_END_MODULE_GLOBALS(sqlsrv)