Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions contrib/babelfishpg_tds/src/backend/tds/tdslogin.c
Original file line number Diff line number Diff line change
Expand Up @@ -2033,10 +2033,9 @@ TdsProcessLogin(Port *port, bool loadedSsl)
void
TdsSetDbContext()
{
char *dbname = NULL;
char *useDbCommand = NULL;
char *user = NULL;
MemoryContext oldContext = CurrentMemoryContext;
char *dbname = NULL;
char *user = NULL;
MemoryContext oldContext = CurrentMemoryContext;

PG_TRY();
{
Expand All @@ -2059,11 +2058,6 @@ TdsSetDbContext()
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("database \"%s\" does not exist", loginInfo->database)));

/*
* Any delimitated/quoted db name identifier requested in login
* must be already handled before this point.
*/
useDbCommand = psprintf("USE [%s]", loginInfo->database);
dbname = pstrdup(loginInfo->database);
}
else
Expand All @@ -2078,8 +2072,6 @@ TdsSetDbContext()
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("could not find default database for user \"%s\"", loginInfo->username)));

useDbCommand = psprintf("USE [%s]", temp);
dbname = pstrdup(temp);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);
Expand All @@ -2096,10 +2088,9 @@ TdsSetDbContext()
errmsg("Cannot open database \"%s\" requested by the login. The login failed", dbname)));

/*
* loginInfo has a database name provided, so we execute a "USE
* [<db_name>]" through pltsql inline handler.
* Direct API invokation for switching database context.
*/
ExecuteSQLBatch(useDbCommand);
pltsql_plugin_handler_ptr->switch_database_context(dbname);
CommitTransactionCommand();
}
PG_CATCH();
Expand Down Expand Up @@ -2136,8 +2127,6 @@ TdsSetDbContext()
PG_RE_THROW();
}
PG_END_TRY();
if (useDbCommand)
pfree(useDbCommand);
if (dbname)
pfree(dbname);
}
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/pl_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -3695,6 +3695,7 @@ _PG_init(void)
(*pltsql_protocol_plugin_ptr)->pltsql_get_logical_schema_name = &get_logical_schema_name;
(*pltsql_protocol_plugin_ptr)->pltsql_is_fmtonly_stmt = &pltsql_fmtonly;
(*pltsql_protocol_plugin_ptr)->pltsql_get_user_for_database = &get_user_for_database;
(*pltsql_protocol_plugin_ptr)->switch_database_context = &switch_database_context;
(*pltsql_protocol_plugin_ptr)->get_insert_bulk_rows_per_batch = &get_insert_bulk_rows_per_batch;
(*pltsql_protocol_plugin_ptr)->get_insert_bulk_kilobytes_per_batch = &get_insert_bulk_kilobytes_per_batch;
(*pltsql_protocol_plugin_ptr)->tsql_varchar_input = common_utility_plugin_ptr->tsql_varchar_input;
Expand Down
2 changes: 2 additions & 0 deletions contrib/babelfishpg_tsql/src/pltsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,8 @@ typedef struct PLtsql_protocol_plugin

char *(*pltsql_get_user_for_database) (const char *db_name);

void (*switch_database_context) (const char *dbname);

char *(*TsqlEncodingConversion) (const char *s, int len, int encoding, int *encodedByteLen);

int (*TdsGetEncodingFromLcid) (int32_t lcid);
Expand Down
34 changes: 33 additions & 1 deletion contrib/babelfishpg_tsql/src/session.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,38 @@ set_cur_user_db_and_path(const char *db_name)
set_search_path_for_user_schema(db_name, user);
}

/*
* switch_database_context - Switching database context during login
*
* This function performs all necessary steps to securely switch database context:
* 1. Validates user has access to the database
* 2. Acquires session-level lock on the target database
* 3. Sets the database, user, and search path
*
*/
void
switch_database_context(const char *dbname)
{
int16 new_db_id;

/* Validate user has access to the database */
check_session_db_access(dbname);

/* Get database ID for lock acquisition */
new_db_id = get_db_id(dbname);

/* Acquire session-level lock on the new database */
if (!TryLockLogicalDatabaseForSession(new_db_id, ShareLock))
ereport(ERROR,
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("Cannot use database \"%s\", failed to obtain lock. "
"\"%s\" is probably undergoing DDL statements in another session.",
dbname, dbname)));

/* Set database context, user, and search path */
set_cur_user_db_and_path(dbname);
}

static void
set_search_path_for_user_schema(const char *db_name, const char *user)
{
Expand Down Expand Up @@ -456,4 +488,4 @@ babelfixedparallelstate_restore(shm_toc *toc)

/* Set the logcial db name for parallel workers */
set_cur_db_name_for_parallel_worker(bfps->logical_db_name);
}
}
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ extern void set_cur_user_db_and_path(const char *db_name);
extern void restore_session_properties(void);
extern void reset_session_properties(void);
extern void set_cur_db_name_for_parallel_worker(const char* logical_db_name);
extern void switch_database_context(const char *dbname);

/* Hooks for parallel workers for babelfish fixed state */
extern void babelfixedparallelstate_insert(ParallelContext *pcxt, bool estimate);
Expand Down