Skip to content

Commit 287c677

Browse files
MazterQyouigorlukanin
authored andcommitted
feat(cubesql): SET ROLE changes authentication context (#9982)
1 parent 0289943 commit 287c677

File tree

10 files changed

+144
-77
lines changed

10 files changed

+144
-77
lines changed

rust/cubesql/cubesql/e2e/tests/postgres.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,16 +1302,6 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
13021302
)
13031303
.await?;
13041304

1305-
self.test_simple_query(r#"SET ROLE "cube""#.to_string(), |messages| {
1306-
assert_eq!(messages.len(), 1);
1307-
1308-
// SET
1309-
if let SimpleQueryMessage::Row(_) = messages[0] {
1310-
panic!("Must be CommandComplete command, (SET is used)")
1311-
}
1312-
})
1313-
.await?;
1314-
13151305
// Tableau Desktop
13161306
self.test_simple_query(
13171307
r#"SET DateStyle = 'ISO';SET extra_float_digits = 2;show transaction_isolation"#

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5177,15 +5177,41 @@ ORDER BY
51775177
);
51785178

51795179
insta::assert_snapshot!(
5180-
"pg_set_role_show",
5180+
"pg_set_role_good_user",
51815181
execute_queries_with_flags(
5182-
vec!["SET ROLE NONE".to_string(), "SHOW ROLE".to_string()],
5182+
vec!["SET ROLE good_user".to_string(), "SHOW ROLE".to_string()],
51835183
DatabaseProtocol::PostgreSQL
51845184
)
51855185
.await?
51865186
.0
51875187
);
51885188

5189+
insta::assert_snapshot!(
5190+
"pg_set_role_none",
5191+
execute_queries_with_flags(
5192+
vec![
5193+
"SET ROLE good_user".to_string(),
5194+
"SET ROLE NONE".to_string(),
5195+
"SHOW ROLE".to_string()
5196+
],
5197+
DatabaseProtocol::PostgreSQL
5198+
)
5199+
.await?
5200+
.0
5201+
);
5202+
5203+
insta::assert_snapshot!(
5204+
"pg_set_role_bad_user",
5205+
execute_queries_with_flags(
5206+
vec!["SET ROLE bad_user".to_string()],
5207+
DatabaseProtocol::PostgreSQL
5208+
)
5209+
.await
5210+
.err()
5211+
.unwrap()
5212+
.to_string()
5213+
);
5214+
51895215
Ok(())
51905216
}
51915217

rust/cubesql/cubesql/src/compile/router.rs

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ impl QueryRouter {
102102
StatusFlags::empty(),
103103
Box::new(dataframe::DataFrame::new(vec![], vec![])),
104104
)),
105-
(ast::Statement::SetRole { role_name, .. }, _) => self.set_role_to_plan(role_name),
105+
(ast::Statement::SetRole { role_name, .. }, _) => {
106+
self.set_role_to_plan(role_name).await
107+
}
106108
(ast::Statement::SetVariable { key_values }, _) => {
107109
self.set_variable_to_plan(&key_values).await
108110
}
@@ -283,19 +285,24 @@ impl QueryRouter {
283285
}
284286
}
285287

286-
fn set_role_to_plan(
288+
async fn set_role_to_plan(
287289
&self,
288290
role_name: &Option<ast::Ident>,
289291
) -> Result<QueryPlan, CompilationError> {
290292
let flags = StatusFlags::SERVER_STATE_CHANGED;
291-
let role_name = role_name
292-
.as_ref()
293-
.map(|role_name| role_name.value.clone())
294-
.unwrap_or("none".to_string());
295-
let variable =
296-
DatabaseVariable::system("role".to_string(), ScalarValue::Utf8(Some(role_name)), None);
293+
let username = role_name.as_ref().map(|role_name| role_name.value.clone());
294+
let Some(to_user) = username.clone().or_else(|| self.state.original_user()) else {
295+
return Err(CompilationError::user(
296+
"Cannot reset role when original role has not been set".to_string(),
297+
));
298+
};
299+
self.change_user(to_user).await?;
300+
let variable = DatabaseVariable::system(
301+
"role".to_string(),
302+
ScalarValue::Utf8(Some(username.unwrap_or("none".to_string()))),
303+
None,
304+
);
297305
self.state.set_variables(vec![variable]);
298-
299306
Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set))
300307
}
301308

@@ -419,11 +426,6 @@ impl QueryRouter {
419426
});
420427

421428
for v in user_variables {
422-
self.reauthenticate_if_needed().await?;
423-
424-
let auth_context = self.state.auth_context().ok_or(CompilationError::user(
425-
"No auth context set but tried to set current user".to_string(),
426-
))?;
427429
let to_user = match v.value {
428430
ScalarValue::Utf8(Some(user)) => user,
429431
_ => {
@@ -433,46 +435,7 @@ impl QueryRouter {
433435
)))
434436
}
435437
};
436-
if self
437-
.session_manager
438-
.server
439-
.transport
440-
.can_switch_user_for_session(auth_context.clone(), to_user.clone())
441-
.await
442-
.map_err(|e| {
443-
CompilationError::internal(format!(
444-
"Error calling can_switch_user_for_session: {}",
445-
e
446-
))
447-
})?
448-
{
449-
self.state.set_user(Some(to_user.clone()));
450-
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
451-
protocol: "postgres".to_string(),
452-
method: "password".to_string(),
453-
};
454-
let authenticate_response = self
455-
.session_manager
456-
.server
457-
.auth
458-
.authenticate(sql_auth_request, Some(to_user.clone()), None)
459-
.await
460-
.map_err(|e| {
461-
CompilationError::internal(format!("Error calling authenticate: {}", e))
462-
})?;
463-
self.state
464-
.set_auth_context(Some(authenticate_response.context));
465-
} else {
466-
return Err(CompilationError::user(format!(
467-
"user '{}' is not allowed to switch to '{}'",
468-
auth_context
469-
.user()
470-
.as_ref()
471-
.map(|v| v.as_str())
472-
.unwrap_or("not specified"),
473-
to_user
474-
)));
475-
}
438+
self.change_user(to_user).await?;
476439
}
477440

478441
if !session_columns_to_update.is_empty() {
@@ -488,6 +451,56 @@ impl QueryRouter {
488451
Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set))
489452
}
490453

454+
async fn change_user(&self, username: String) -> Result<(), CompilationError> {
455+
self.reauthenticate_if_needed().await?;
456+
457+
let auth_context = self.state.auth_context().ok_or(CompilationError::user(
458+
"No auth context set but tried to set current user".to_string(),
459+
))?;
460+
461+
let can_switch_user = self
462+
.session_manager
463+
.server
464+
.transport
465+
.can_switch_user_for_session(auth_context.clone(), username.clone())
466+
.await
467+
.map_err(|e| {
468+
CompilationError::internal(format!(
469+
"Error calling can_switch_user_for_session: {}",
470+
e
471+
))
472+
})?;
473+
if !can_switch_user {
474+
return Err(CompilationError::user(format!(
475+
"user '{}' is not allowed to switch to '{}'",
476+
auth_context
477+
.user()
478+
.as_ref()
479+
.map(|v| v.as_str())
480+
.unwrap_or("not specified"),
481+
username
482+
)));
483+
}
484+
485+
self.state.set_user(Some(username.clone()));
486+
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
487+
protocol: "postgres".to_string(),
488+
method: "password".to_string(),
489+
};
490+
let authenticate_response = self
491+
.session_manager
492+
.server
493+
.auth
494+
.authenticate(sql_auth_request, Some(username), None)
495+
.await
496+
.map_err(|e| {
497+
CompilationError::internal(format!("Error calling authenticate: {}", e))
498+
})?;
499+
self.state
500+
.set_auth_context(Some(authenticate_response.context));
501+
Ok(())
502+
}
503+
491504
async fn create_table_to_plan(
492505
&self,
493506
name: &ast::ObjectName,
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
source: cubesql/src/compile/mod.rs
3+
expression: "execute_queries_with_flags(vec![\"SET ROLE bad_user\".to_string(),],\nDatabaseProtocol::PostgreSQL).await.err().unwrap().to_string()"
4+
---
5+
Error during planning: SQLCompilationError: User: user 'not specified' is not allowed to switch to 'bad_user'
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: cubesql/src/compile/mod.rs
3+
expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SHOW ROLE\".to_string()], DatabaseProtocol::PostgreSQL).await? .0"
4+
---
5+
+-----------+
6+
| setting |
7+
+-----------+
8+
| good_user |
9+
+-----------+
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: cubesql/src/compile/mod.rs
3+
expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SET ROLE NONE\".to_string(), \"SHOW ROLE\".to_string()],\nDatabaseProtocol::PostgreSQL).await? .0"
4+
---
5+
+---------+
6+
| setting |
7+
+---------+
8+
| none |
9+
+---------+

rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_show.snap

Lines changed: 0 additions & 9 deletions
This file was deleted.

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ async fn get_test_session_with_config_and_transport(
771771
// Populate like shims
772772
session.state.set_database(Some(db_name.to_string()));
773773
session.state.set_user(Some("ovr".to_string()));
774+
session.state.set_original_user(Some("ovr".to_string()));
774775

775776
let auth_ctx = HttpAuthContext {
776777
access_token: "access_token".to_string(),
@@ -938,7 +939,7 @@ impl TransportService for TestConnectionTransport {
938939
_ctx: AuthContextRef,
939940
to_user: String,
940941
) -> Result<bool, CubeError> {
941-
if to_user == "good_user" {
942+
if matches!(to_user.as_str(), "good_user" | "ovr") {
942943
Ok(true)
943944
} else {
944945
Ok(false)

rust/cubesql/cubesql/src/sql/postgres/shim.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,8 @@ impl AsyncPostgresShim {
806806
.cloned()
807807
.unwrap_or("db".to_string());
808808
self.session.state.set_database(Some(database));
809-
self.session.state.set_user(Some(user));
809+
self.session.state.set_user(Some(user.clone()));
810+
self.session.state.set_original_user(Some(user));
810811
self.session.state.set_auth_context(Some(auth_context));
811812

812813
self.write(protocol::Authentication::new(AuthenticationRequest::Ok))

rust/cubesql/cubesql/src/sql/session.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ pub struct SessionState {
8282
// @todo Remove RWLock after split of Connection & SQLWorker
8383
// Context for Transport
8484
auth_context: RwLockSync<(Option<AuthContextRef>, SystemTime)>,
85+
// Used to reset user with SET ROLE NONE
86+
original_user: RwLockSync<Option<String>>,
8587

8688
transaction: RwLockSync<TransactionState>,
8789
query: RwLockSync<QueryState>,
@@ -116,6 +118,7 @@ impl SessionState {
116118
temp_tables: Arc::new(TempTableManager::new(session_manager)),
117119
properties: RwLockSync::new(SessionProperties::new(None, None)),
118120
auth_context: RwLockSync::new((auth_context, SystemTime::now())),
121+
original_user: RwLockSync::new(None),
119122
transaction: RwLockSync::new(TransactionState::None),
120123
query: RwLockSync::new(QueryState::None),
121124
statements: RWLockAsync::new(HashMap::new()),
@@ -271,6 +274,25 @@ impl SessionState {
271274
guard.user = user;
272275
}
273276

277+
pub fn original_user(&self) -> Option<String> {
278+
let guard = self
279+
.original_user
280+
.read()
281+
.expect("failed to unlock original_user for reading");
282+
guard.clone()
283+
}
284+
285+
pub fn set_original_user(&self, user: Option<String>) {
286+
let mut guard = self
287+
.original_user
288+
.write()
289+
.expect("failed to unlock original_user for writing");
290+
if guard.is_none() {
291+
// Silently ignore writing original user if it's already set
292+
*guard = user;
293+
}
294+
}
295+
274296
pub fn database(&self) -> Option<String> {
275297
let guard = self
276298
.properties

0 commit comments

Comments
 (0)