Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions contrib/babelfishpg_tds/src/backend/tds/tdslogin.c
Original file line number Diff line number Diff line change
Expand Up @@ -1930,7 +1930,6 @@ TdsSendLoginAck(Port *port)
uint8 temp8;
uint32_t collationInfo;
char collationBytesNew[5];
char *useDbCommand = NULL;
MemoryContext oldContext;
uint32_t tdsVersion = pg_hton32(loginInfo->tdsVersion);

Expand Down Expand Up @@ -2057,8 +2056,6 @@ TdsSendLoginAck(Port *port)
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("database \"%s\" does not exist", request->database)));

/* Any delimitated/quoted db name identifier requested in login must be already handled before this point. */
useDbCommand = psprintf("USE [%s]", request->database);
}
else
{
Expand All @@ -2072,8 +2069,7 @@ TdsSendLoginAck(Port *port)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_DATABASE),
errmsg("could not find default database for user \"%s\"", port->user_name)));

useDbCommand = psprintf("USE [%s]", temp);
dbname = pstrdup(temp);
CommitTransactionCommand();
MemoryContextSwitchTo(oldContext);
}
Expand All @@ -2083,10 +2079,10 @@ TdsSendLoginAck(Port *port)
* a "USE [<db_name>]" through pgtsql inline handler
*/
StartTransactionCommand();
ExecuteSQLBatch(useDbCommand);
pltsql_plugin_handler_ptr->switch_database_context(dbname);
CommitTransactionCommand();
if (useDbCommand)
pfree(useDbCommand);
if (dbname)
pfree(dbname);

/*
* Set the GUC for language, it will take care of
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/pl_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -3200,6 +3200,7 @@ _PG_init(void)
(*pltsql_protocol_plugin_ptr)->pltsql_is_login = &is_login;
(*pltsql_protocol_plugin_ptr)->pltsql_get_generic_typmod = &probin_read_ret_typmod;
(*pltsql_protocol_plugin_ptr)->pltsql_is_fmtonly_stmt = &pltsql_fmtonly;
(*pltsql_protocol_plugin_ptr)->switch_database_context = &switch_database_context;
}

*pltsql_config_ptr = &myConfig;
Expand Down
2 changes: 2 additions & 0 deletions contrib/babelfishpg_tsql/src/pltsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,8 @@ typedef struct PLtsql_protocol_plugin

bool *pltsql_is_fmtonly_stmt;

void* (*switch_database_context) (const char *dbname);

} PLtsql_protocol_plugin;

/*
Expand Down
26 changes: 26 additions & 0 deletions contrib/babelfishpg_tsql/src/session.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,32 @@ set_session_properties(const char *db_name)
PGC_S_DATABASE_USER);
}

/*
* switch_database_context - Switching database context during login
*
* This function performs all necessary steps to securely switch database context:
* 1. Validates user has access to the database
* 2. Acquires session-level lock on the target database
* 3. Sets the database, user, and search path
*
*/
void
switch_database_context(const char *dbname)
{
int16 new_db_id;

/* Acquire session-level lock on the new database */
if (!TryLockLogicalDatabaseForSession(new_db_id, ShareLock))
ereport(ERROR,
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("Cannot use database \"%s\", failed to obtain lock. "
"\"%s\" is probably undergoing DDL statements in another session.",
dbname, dbname)));

/* Set database context, user, and search path */
set_session_properties(dbname);
}

/*
* Wrapper function to reset the session properties and cached batch
* incase of a reset connection.
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ extern void bbf_set_current_user(const char *user_name);
extern void set_session_properties(const char *db_name);
extern void restore_session_properties(void);
extern void reset_session_properties(void);
extern void switch_database_context(const char *dbname);

#endif
Loading