diff --git a/contrib/babelfishpg_tds/src/backend/tds/tdslogin.c b/contrib/babelfishpg_tds/src/backend/tds/tdslogin.c index 9c648bd0953..d22eea0a138 100644 --- a/contrib/babelfishpg_tds/src/backend/tds/tdslogin.c +++ b/contrib/babelfishpg_tds/src/backend/tds/tdslogin.c @@ -1930,7 +1930,6 @@ TdsSendLoginAck(Port *port) uint8 temp8; uint32_t collationInfo; char collationBytesNew[5]; - char *useDbCommand = NULL; MemoryContext oldContext; uint32_t tdsVersion = pg_hton32(loginInfo->tdsVersion); @@ -2057,8 +2056,6 @@ TdsSendLoginAck(Port *port) (errcode(ERRCODE_UNDEFINED_DATABASE), errmsg("database \"%s\" does not exist", request->database))); - /* Any delimitated/quoted db name identifier requested in login must be already handled before this point. */ - useDbCommand = psprintf("USE [%s]", request->database); } else { @@ -2072,8 +2069,7 @@ TdsSendLoginAck(Port *port) ereport(ERROR, (errcode(ERRCODE_UNDEFINED_DATABASE), errmsg("could not find default database for user \"%s\"", port->user_name))); - - useDbCommand = psprintf("USE [%s]", temp); + dbname = pstrdup(temp); CommitTransactionCommand(); MemoryContextSwitchTo(oldContext); } @@ -2083,10 +2079,10 @@ TdsSendLoginAck(Port *port) * a "USE []" through pgtsql inline handler */ StartTransactionCommand(); - ExecuteSQLBatch(useDbCommand); + pltsql_plugin_handler_ptr->switch_database_context(dbname); CommitTransactionCommand(); - if (useDbCommand) - pfree(useDbCommand); + if (dbname) + pfree(dbname); /* * Set the GUC for language, it will take care of diff --git a/contrib/babelfishpg_tsql/src/pl_handler.c b/contrib/babelfishpg_tsql/src/pl_handler.c index 88e2f8ab843..6880a6b2be8 100644 --- a/contrib/babelfishpg_tsql/src/pl_handler.c +++ b/contrib/babelfishpg_tsql/src/pl_handler.c @@ -3200,6 +3200,7 @@ _PG_init(void) (*pltsql_protocol_plugin_ptr)->pltsql_is_login = &is_login; (*pltsql_protocol_plugin_ptr)->pltsql_get_generic_typmod = &probin_read_ret_typmod; (*pltsql_protocol_plugin_ptr)->pltsql_is_fmtonly_stmt = &pltsql_fmtonly; + (*pltsql_protocol_plugin_ptr)->switch_database_context = &switch_database_context; } *pltsql_config_ptr = &myConfig; diff --git a/contrib/babelfishpg_tsql/src/pltsql.h b/contrib/babelfishpg_tsql/src/pltsql.h index 864e4955865..3a0c0688429 100644 --- a/contrib/babelfishpg_tsql/src/pltsql.h +++ b/contrib/babelfishpg_tsql/src/pltsql.h @@ -1561,6 +1561,8 @@ typedef struct PLtsql_protocol_plugin bool *pltsql_is_fmtonly_stmt; + void* (*switch_database_context) (const char *dbname); + } PLtsql_protocol_plugin; /* diff --git a/contrib/babelfishpg_tsql/src/session.c b/contrib/babelfishpg_tsql/src/session.c index 501067077a3..595a432e223 100644 --- a/contrib/babelfishpg_tsql/src/session.c +++ b/contrib/babelfishpg_tsql/src/session.c @@ -116,6 +116,32 @@ set_session_properties(const char *db_name) PGC_S_DATABASE_USER); } +/* + * switch_database_context - Switching database context during login + * + * This function performs all necessary steps to securely switch database context: + * 1. Validates user has access to the database + * 2. Acquires session-level lock on the target database + * 3. Sets the database, user, and search path + * + */ +void +switch_database_context(const char *dbname) +{ + int16 new_db_id; + + /* Acquire session-level lock on the new database */ + if (!TryLockLogicalDatabaseForSession(new_db_id, ShareLock)) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("Cannot use database \"%s\", failed to obtain lock. " + "\"%s\" is probably undergoing DDL statements in another session.", + dbname, dbname))); + + /* Set database context, user, and search path */ + set_session_properties(dbname); +} + /* * Wrapper function to reset the session properties and cached batch * incase of a reset connection. diff --git a/contrib/babelfishpg_tsql/src/session.h b/contrib/babelfishpg_tsql/src/session.h index 3dc3ff11682..7ee0885e559 100644 --- a/contrib/babelfishpg_tsql/src/session.h +++ b/contrib/babelfishpg_tsql/src/session.h @@ -9,5 +9,6 @@ extern void bbf_set_current_user(const char *user_name); extern void set_session_properties(const char *db_name); extern void restore_session_properties(void); extern void reset_session_properties(void); +extern void switch_database_context(const char *dbname); #endif