Skip to content

Commit 401a845

Browse files
authored
feat: introduce "protocol" and "method" props for request param in checkSqlAuth (#9525)
1 parent fa1d9f4 commit 401a845

File tree

9 files changed

+84
-11
lines changed

9 files changed

+84
-11
lines changed

packages/cubejs-api-gateway/src/sql-server.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ export type SQLServerConstructorOptions = {
3131
gatewayPort?: number,
3232
};
3333

34+
export type SqlAuthServiceAuthenticateRequest = {
35+
protocol: string;
36+
method: string;
37+
};
38+
3439
export class SQLServer {
3540
protected sqlInterfaceInstance: SqlInterfaceInstance | null = null;
3641

@@ -88,10 +93,14 @@ export class SQLServer {
8893
let { securityContext } = session;
8994

9095
if (request.meta.changeUser && request.meta.changeUser !== session.user) {
96+
const sqlAuthRequest: SqlAuthServiceAuthenticateRequest = {
97+
protocol: request.meta.protocol,
98+
method: 'password',
99+
};
91100
const canSwitch = session.superuser || await canSwitchSqlUser(session.user, request.meta.changeUser);
92101
if (canSwitch) {
93102
userForContext = request.meta.changeUser;
94-
const current = await checkSqlAuth(request, userForContext, null);
103+
const current = await checkSqlAuth({ ...request, ...sqlAuthRequest }, userForContext, null);
95104
securityContext = current.securityContext;
96105
} else {
97106
throw new Error(

packages/cubejs-backend-native/src/auth.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use async_trait::async_trait;
22
use cubesql::{
33
di_service,
4-
sql::{AuthContext, AuthenticateResponse, SqlAuthService},
4+
sql::{AuthContext, AuthenticateResponse, SqlAuthService, SqlAuthServiceAuthenticateRequest},
55
transport::LoadRequestMeta,
66
CubeError,
77
};
@@ -50,9 +50,28 @@ pub struct TransportRequest {
5050
pub meta: Option<LoadRequestMeta>,
5151
}
5252

53+
#[derive(Debug, Serialize)]
54+
pub struct TransportAuthRequest {
55+
pub id: String,
56+
pub meta: Option<LoadRequestMeta>,
57+
pub protocol: String,
58+
pub method: String,
59+
}
60+
61+
impl From<(TransportRequest, SqlAuthServiceAuthenticateRequest)> for TransportAuthRequest {
62+
fn from((t, a): (TransportRequest, SqlAuthServiceAuthenticateRequest)) -> Self {
63+
Self {
64+
id: t.id,
65+
meta: t.meta,
66+
protocol: a.protocol,
67+
method: a.method,
68+
}
69+
}
70+
}
71+
5372
#[derive(Debug, Serialize)]
5473
struct CheckSQLAuthTransportRequest {
55-
request: TransportRequest,
74+
request: TransportAuthRequest,
5675
user: Option<String>,
5776
password: Option<String>,
5877
}
@@ -92,6 +111,7 @@ impl AuthContext for NativeSQLAuthContext {
92111
impl SqlAuthService for NodeBridgeAuthService {
93112
async fn authenticate(
94113
&self,
114+
request: SqlAuthServiceAuthenticateRequest,
95115
user: Option<String>,
96116
password: Option<String>,
97117
) -> Result<AuthenticateResponse, CubeError> {
@@ -100,9 +120,11 @@ impl SqlAuthService for NodeBridgeAuthService {
100120
let request_id = Uuid::new_v4().to_string();
101121

102122
let extra = serde_json::to_string(&CheckSQLAuthTransportRequest {
103-
request: TransportRequest {
123+
request: TransportAuthRequest {
104124
id: format!("{}-span-1", request_id),
105125
meta: None,
126+
protocol: request.protocol,
127+
method: request.method,
106128
},
107129
user: user.clone(),
108130
password: password.clone(),

packages/cubejs-backend-native/test/sql.test.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ describe('SQLInterface', () => {
201201
request: {
202202
id: expect.any(String),
203203
meta: null,
204+
method: expect.any(String),
205+
protocol: expect.any(String),
204206
},
205207
user: user || null,
206208
password:
@@ -258,6 +260,8 @@ describe('SQLInterface', () => {
258260
request: {
259261
id: expect.any(String),
260262
meta: null,
263+
method: expect.any(String),
264+
protocol: expect.any(String),
261265
},
262266
user: 'allowed_user',
263267
password: 'password_for_allowed_user',

packages/cubejs-testing/birdbox-fixtures/postgresql/single/sqlapi.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ module.exports = {
1313
return query;
1414
},
1515
checkSqlAuth: async (req, user, password) => {
16+
if (!req) {
17+
throw new Error('Request is not defined');
18+
}
19+
20+
const missing = ['protocol', 'method'].filter(key => !(key in req));
21+
if (missing.length) {
22+
throw new Error(`Request object is missing required field(s): ${missing.join(', ')}`);
23+
}
24+
1625
if (user === 'admin') {
1726
if (password && password !== 'admin_password') {
1827
throw new Error(`Password doesn't match for ${user}`);

rust/cubesql/cubesql/src/compile/router.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::{
1212
DatabaseVariable, DatabaseVariablesToUpdate,
1313
},
1414
sql::{
15+
auth_service::SqlAuthServiceAuthenticateRequest,
1516
dataframe,
1617
statement::{
1718
ApproximateCountDistinctVisitor, CastReplacer, DateTokenNormalizeReplacer,
@@ -447,12 +448,16 @@ impl QueryRouter {
447448
})?
448449
{
449450
self.state.set_user(Some(to_user.clone()));
451+
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
452+
protocol: "postgres".to_string(),
453+
method: "password".to_string(),
454+
};
450455
let authenticate_response = self
451456
.session_manager
452457
.server
453458
.auth
454459
// TODO do we want to send actual password here?
455-
.authenticate(Some(to_user.clone()), None)
460+
.authenticate(sql_auth_request, Some(to_user.clone()), None)
456461
.await
457462
.map_err(|e| {
458463
CompilationError::internal(format!("Error calling authenticate: {}", e))
@@ -562,11 +567,15 @@ impl QueryRouter {
562567

563568
async fn reauthenticate_if_needed(&self) -> CompilationResult<()> {
564569
if self.state.is_auth_context_expired() {
570+
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
571+
protocol: "postgres".to_string(),
572+
method: "password".to_string(),
573+
};
565574
let authenticate_response = self
566575
.session_manager
567576
.server
568577
.auth
569-
.authenticate(self.state.user(), None)
578+
.authenticate(sql_auth_request, self.state.user(), None)
570579
.await
571580
.map_err(|e| {
572581
CompilationError::fatal(format!(

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ use crate::{
88
},
99
config::{ConfigObj, ConfigObjImpl},
1010
sql::{
11-
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe,
12-
pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse,
13-
HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService,
11+
auth_service::SqlAuthServiceAuthenticateRequest, compiler_cache::CompilerCacheImpl,
12+
dataframe::batches_to_dataframe, pg_auth_service::PostgresAuthServiceDefaultImpl,
13+
AuthContextRef, AuthenticateResponse, HttpAuthContext, ServerManager, Session,
14+
SessionManager, SqlAuthService,
1415
},
1516
transport::{
1617
CubeMeta, CubeMetaDimension, CubeMetaJoin, CubeMetaMeasure, CubeMetaSegment,
@@ -747,6 +748,7 @@ pub fn get_test_auth() -> Arc<dyn SqlAuthService> {
747748
impl SqlAuthService for TestSqlAuth {
748749
async fn authenticate(
749750
&self,
751+
_request: SqlAuthServiceAuthenticateRequest,
750752
_user: Option<String>,
751753
password: Option<String>,
752754
) -> Result<AuthenticateResponse, CubeError> {

rust/cubesql/cubesql/src/sql/auth_service.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::{any::Any, env, fmt::Debug, sync::Arc};
22

33
use crate::CubeError;
44
use async_trait::async_trait;
5+
use serde::{Deserialize, Serialize};
56
use serde_json::Value;
67

78
// We cannot use generic here. It's why there is this trait
@@ -43,10 +44,17 @@ pub struct AuthenticateResponse {
4344
pub skip_password_check: bool,
4445
}
4546

47+
#[derive(Debug, Clone, Serialize, Deserialize)]
48+
pub struct SqlAuthServiceAuthenticateRequest {
49+
pub protocol: String,
50+
pub method: String,
51+
}
52+
4653
#[async_trait]
4754
pub trait SqlAuthService: Send + Sync + Debug {
4855
async fn authenticate(
4956
&self,
57+
request: SqlAuthServiceAuthenticateRequest,
5058
user: Option<String>,
5159
password: Option<String>,
5260
) -> Result<AuthenticateResponse, CubeError>;
@@ -61,6 +69,7 @@ crate::di_service!(SqlAuthDefaultImpl, [SqlAuthService]);
6169
impl SqlAuthService for SqlAuthDefaultImpl {
6270
async fn authenticate(
6371
&self,
72+
_request: SqlAuthServiceAuthenticateRequest,
6473
_user: Option<String>,
6574
password: Option<String>,
6675
) -> Result<AuthenticateResponse, CubeError> {

rust/cubesql/cubesql/src/sql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub(crate) mod types;
1313
// Public API
1414
pub use auth_service::{
1515
AuthContext, AuthContextRef, AuthenticateResponse, HttpAuthContext, SqlAuthDefaultImpl,
16-
SqlAuthService,
16+
SqlAuthService, SqlAuthServiceAuthenticateRequest,
1717
};
1818
pub use database_variables::postgres::session_vars::CUBESQL_PENALIZE_POST_PROCESSING_VAR;
1919
pub use postgres::*;

rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{collections::HashMap, fmt::Debug, sync::Arc};
33
use async_trait::async_trait;
44

55
use crate::{
6+
sql::auth_service::SqlAuthServiceAuthenticateRequest,
67
sql::{AuthContextRef, SqlAuthService},
78
CubeError,
89
};
@@ -74,8 +75,16 @@ impl PostgresAuthService for PostgresAuthServiceDefaultImpl {
7475
}
7576

7677
let user = parameters.get("user").unwrap().clone();
78+
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
79+
protocol: "postgres".to_string(),
80+
method: "password".to_string(),
81+
};
7782
let authenticate_response = service
78-
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
83+
.authenticate(
84+
sql_auth_request,
85+
Some(user.clone()),
86+
Some(password_message.password.clone()),
87+
)
7988
.await;
8089

8190
let auth_fail = || {

0 commit comments

Comments
 (0)