Skip to content

Commit 525f2b2

Browse files
committed
refactor(backend-native): Extract with_session function
1 parent 8e6fe39 commit 525f2b2

File tree

2 files changed

+157
-138
lines changed

2 files changed

+157
-138
lines changed

packages/cubejs-backend-native/src/cubesql_utils.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::future::Future;
12
use std::net::SocketAddr;
23
use std::str::FromStr;
34
use std::sync::Arc;
@@ -49,3 +50,27 @@ pub async fn create_session(
4950

5051
Ok(session)
5152
}
53+
54+
pub async fn with_session<T, F, Fut>(
55+
services: &NodeCubeServices,
56+
native_auth_ctx: Arc<NativeAuthContext>,
57+
f: F,
58+
) -> Result<T, CubeError>
59+
where
60+
F: FnOnce(Arc<Session>) -> Fut,
61+
Fut: Future<Output = Result<T, CubeError>>,
62+
{
63+
let session_manager = services
64+
.injector()
65+
.get_service_typed::<SessionManager>()
66+
.await;
67+
let session = create_session(services, native_auth_ctx).await?;
68+
let connection_id = session.state.connection_id;
69+
70+
// From now there's a session we should close before returning, as in `finally`
71+
let result = { f(session).await };
72+
73+
session_manager.drop_session(connection_id).await;
74+
75+
result
76+
}

packages/cubejs-backend-native/src/node_export.rs

Lines changed: 132 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use cubesql::compile::{convert_sql_to_cube_query, get_df_batches};
22
use cubesql::config::processing_loop::ShutdownMode;
3-
use cubesql::sql::SessionManager;
43
use cubesql::transport::TransportService;
54
use futures::StreamExt;
65

@@ -11,7 +10,7 @@ use crate::auth::{NativeAuthContext, NodeBridgeAuthService};
1110
use crate::channel::call_js_fn;
1211
use crate::config::{NodeConfiguration, NodeConfigurationFactoryOptions, NodeCubeServices};
1312
use crate::cross::CLRepr;
14-
use crate::cubesql_utils::create_session;
13+
use crate::cubesql_utils::with_session;
1514
use crate::logger::NodeBridgeLogger;
1615
use crate::stream::OnDrainHandler;
1716
use crate::tokio_runtime_node;
@@ -185,77 +184,98 @@ async fn handle_sql_query(
185184
) -> Result<(), CubeError> {
186185
let start_time = SystemTime::now();
187186

188-
let session = create_session(&services, native_auth_ctx.clone()).await?;
189-
190-
if let Some(auth_context) = session.state.auth_context() {
191-
session
192-
.session_manager
193-
.server
194-
.transport
195-
.log_load_state(
196-
None,
197-
auth_context,
198-
session.state.get_load_request_meta("sql"),
199-
"Load Request".to_string(),
200-
serde_json::json!({
201-
"query": {
202-
"sql": sql_query,
203-
}
204-
}),
205-
)
206-
.await?;
207-
}
208-
209187
let transport_service = services
210188
.injector()
211189
.get_service_typed::<dyn TransportService>()
212190
.await;
213-
let session_manager = services
214-
.injector()
215-
.get_service_typed::<SessionManager>()
216-
.await;
217191

218-
let session_clone = Arc::clone(&session);
192+
with_session(&services, native_auth_ctx.clone(), |session| async move {
193+
if let Some(auth_context) = session.state.auth_context() {
194+
session
195+
.session_manager
196+
.server
197+
.transport
198+
.log_load_state(
199+
None,
200+
auth_context,
201+
session.state.get_load_request_meta("sql"),
202+
"Load Request".to_string(),
203+
serde_json::json!({
204+
"query": {
205+
"sql": sql_query,
206+
}
207+
}),
208+
)
209+
.await?;
210+
}
219211

220-
let execute = || async move {
221-
// todo: can we use compiler_cache?
222-
let meta_context = transport_service
223-
.meta(native_auth_ctx)
224-
.await
225-
.map_err(|err| CubeError::internal(format!("Failed to get meta context: {}", err)))?;
226-
let query_plan = convert_sql_to_cube_query(sql_query, meta_context, session).await?;
212+
let session_clone = Arc::clone(&session);
227213

228-
let mut stream = get_df_batches(&query_plan).await?;
214+
let execute = || async move {
215+
// todo: can we use compiler_cache?
216+
let meta_context = transport_service
217+
.meta(native_auth_ctx)
218+
.await
219+
.map_err(|err| {
220+
CubeError::internal(format!("Failed to get meta context: {}", err))
221+
})?;
222+
let query_plan = convert_sql_to_cube_query(sql_query, meta_context, session).await?;
229223

230-
let semaphore = Arc::new(Semaphore::new(0));
224+
let mut stream = get_df_batches(&query_plan).await?;
231225

232-
let drain_handler = OnDrainHandler::new(
233-
channel.clone(),
234-
stream_methods.stream.clone(),
235-
semaphore.clone(),
236-
);
226+
let semaphore = Arc::new(Semaphore::new(0));
227+
228+
let drain_handler = OnDrainHandler::new(
229+
channel.clone(),
230+
stream_methods.stream.clone(),
231+
semaphore.clone(),
232+
);
233+
234+
drain_handler.handle(stream_methods.on.clone()).await?;
237235

238-
drain_handler.handle(stream_methods.on.clone()).await?;
236+
let mut is_first_batch = true;
237+
while let Some(batch) = stream.next().await {
238+
let (columns, data) = batch_to_rows(batch?)?;
239239

240-
let mut is_first_batch = true;
241-
while let Some(batch) = stream.next().await {
242-
let (columns, data) = batch_to_rows(batch?)?;
240+
if is_first_batch {
241+
let mut schema = Map::new();
242+
schema.insert("schema".into(), columns);
243+
let columns = format!(
244+
"{}{}",
245+
serde_json::to_string(&serde_json::Value::Object(schema))?,
246+
CHUNK_DELIM
247+
);
248+
is_first_batch = false;
249+
250+
call_js_fn(
251+
channel.clone(),
252+
stream_methods.write.clone(),
253+
Box::new(|cx| {
254+
let arg = cx.string(columns).upcast::<JsValue>();
255+
256+
Ok(vec![arg.upcast::<JsValue>()])
257+
}),
258+
Box::new(|cx, v| match v.downcast_or_throw::<JsBoolean, _>(cx) {
259+
Ok(v) => Ok(v.value(cx)),
260+
Err(_) => Err(CubeError::internal(
261+
"Failed to downcast write response".to_string(),
262+
)),
263+
}),
264+
stream_methods.stream.clone(),
265+
)
266+
.await?;
267+
}
243268

244-
if is_first_batch {
245-
let mut schema = Map::new();
246-
schema.insert("schema".into(), columns);
247-
let columns = format!(
248-
"{}{}",
249-
serde_json::to_string(&serde_json::Value::Object(schema))?,
250-
CHUNK_DELIM
251-
);
252-
is_first_batch = false;
269+
let mut rows = Map::new();
270+
rows.insert("data".into(), serde_json::Value::Array(data));
271+
let data = format!("{}{}", serde_json::to_string(&rows)?, CHUNK_DELIM);
272+
let js_stream_write_fn = stream_methods.write.clone();
253273

254-
call_js_fn(
274+
let should_pause = !call_js_fn(
255275
channel.clone(),
256-
stream_methods.write.clone(),
276+
js_stream_write_fn,
257277
Box::new(|cx| {
258-
let arg = cx.string(columns).upcast::<JsValue>();
278+
let arg = cx.string(data).upcast::<JsValue>();
259279

260280
Ok(vec![arg.upcast::<JsValue>()])
261281
}),
@@ -268,93 +288,67 @@ async fn handle_sql_query(
268288
stream_methods.stream.clone(),
269289
)
270290
.await?;
271-
}
272-
273-
let mut rows = Map::new();
274-
rows.insert("data".into(), serde_json::Value::Array(data));
275-
let data = format!("{}{}", serde_json::to_string(&rows)?, CHUNK_DELIM);
276-
let js_stream_write_fn = stream_methods.write.clone();
277291

278-
let should_pause = !call_js_fn(
279-
channel.clone(),
280-
js_stream_write_fn,
281-
Box::new(|cx| {
282-
let arg = cx.string(data).upcast::<JsValue>();
283-
284-
Ok(vec![arg.upcast::<JsValue>()])
285-
}),
286-
Box::new(|cx, v| match v.downcast_or_throw::<JsBoolean, _>(cx) {
287-
Ok(v) => Ok(v.value(cx)),
288-
Err(_) => Err(CubeError::internal(
289-
"Failed to downcast write response".to_string(),
290-
)),
291-
}),
292-
stream_methods.stream.clone(),
293-
)
294-
.await?;
295-
296-
if should_pause {
297-
let permit = semaphore.acquire().await?;
298-
permit.forget();
292+
if should_pause {
293+
let permit = semaphore.acquire().await?;
294+
permit.forget();
295+
}
299296
}
300-
}
301297

302-
Ok::<(), CubeError>(())
303-
};
298+
Ok::<(), CubeError>(())
299+
};
304300

305-
let result = execute().await;
306-
let duration = start_time.elapsed().unwrap().as_millis() as u64;
301+
let result = execute().await;
302+
let duration = start_time.elapsed().unwrap().as_millis() as u64;
307303

308-
match &result {
309-
Ok(_) => {
310-
session_clone
311-
.session_manager
312-
.server
313-
.transport
314-
.log_load_state(
315-
None,
316-
session_clone.state.auth_context().unwrap(),
317-
session_clone.state.get_load_request_meta("sql"),
318-
"Load Request Success".to_string(),
319-
serde_json::json!({
320-
"query": {
321-
"sql": sql_query,
322-
},
323-
"apiType": "sql",
324-
"duration": duration,
325-
"isDataQuery": true
326-
}),
327-
)
328-
.await?;
329-
}
330-
Err(err) => {
331-
session_clone
332-
.session_manager
333-
.server
334-
.transport
335-
.log_load_state(
336-
None,
337-
session_clone.state.auth_context().unwrap(),
338-
session_clone.state.get_load_request_meta("sql"),
339-
"Cube SQL Error".to_string(),
340-
serde_json::json!({
341-
"query": {
342-
"sql": sql_query
343-
},
344-
"apiType": "sql",
345-
"duration": duration,
346-
"error": err.message,
347-
}),
348-
)
349-
.await?;
304+
match &result {
305+
Ok(_) => {
306+
session_clone
307+
.session_manager
308+
.server
309+
.transport
310+
.log_load_state(
311+
None,
312+
session_clone.state.auth_context().unwrap(),
313+
session_clone.state.get_load_request_meta("sql"),
314+
"Load Request Success".to_string(),
315+
serde_json::json!({
316+
"query": {
317+
"sql": sql_query,
318+
},
319+
"apiType": "sql",
320+
"duration": duration,
321+
"isDataQuery": true
322+
}),
323+
)
324+
.await?;
325+
}
326+
Err(err) => {
327+
session_clone
328+
.session_manager
329+
.server
330+
.transport
331+
.log_load_state(
332+
None,
333+
session_clone.state.auth_context().unwrap(),
334+
session_clone.state.get_load_request_meta("sql"),
335+
"Cube SQL Error".to_string(),
336+
serde_json::json!({
337+
"query": {
338+
"sql": sql_query
339+
},
340+
"apiType": "sql",
341+
"duration": duration,
342+
"error": err.message,
343+
}),
344+
)
345+
.await?;
346+
}
350347
}
351-
}
352-
353-
session_manager
354-
.drop_session(session_clone.state.connection_id)
355-
.await;
356348

357-
result
349+
result
350+
})
351+
.await
358352
}
359353

360354
struct WritableStreamMethods {

0 commit comments

Comments
 (0)