Skip to content

Commit 20c01df

Browse files
Add do get statement handler and DRY (#324)
1 parent 35e5259 commit 20c01df

File tree

1 file changed

+125
-89
lines changed

1 file changed

+125
-89
lines changed

src/server/flightsql/service.rs

Lines changed: 125 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,88 @@ impl FlightSqlServiceImpl {
5353
}
5454
}
5555

56-
/// Return an [`FlightServiceServer`] that can be used with a
56+
/// Return a [`FlightServiceServer`] that can be used with a
5757
/// [`Server`](tonic::transport::Server)
5858
pub fn service(&self) -> FlightServiceServer<Self> {
5959
// wrap up tonic goop
6060
FlightServiceServer::new(self.clone())
6161
}
6262

63+
async fn do_get_common_handler(
64+
&self,
65+
request_id: String,
66+
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
67+
match Uuid::from_str(&request_id) {
68+
Ok(id) => {
69+
info!("getting plan for id: {:?}", id);
70+
// Limit the scope of the lock
71+
let maybe_plan = {
72+
let guard = self
73+
.requests
74+
.lock()
75+
.map_err(|_| Status::internal("Failed to acquire lock on requests"))?;
76+
guard.get(&id).cloned()
77+
};
78+
if let Some(plan) = maybe_plan {
79+
let stream = self
80+
.execution
81+
.execute_logical_plan(plan)
82+
.await
83+
.map_err(|e| Status::internal(e.to_string()))?;
84+
let builder = FlightDataEncoderBuilder::new();
85+
let flight_data_stream = builder
86+
.build(stream.map_err(|e| FlightError::ExternalError(Box::new(e))))
87+
.map_err(|e| Status::internal(e.to_string()))
88+
.boxed();
89+
Ok(Response::new(flight_data_stream))
90+
} else {
91+
Err(Status::internal("Plan not found for id"))
92+
}
93+
}
94+
Err(e) => {
95+
error!("error decoding handle to uuid for {request_id}: {:?}", e);
96+
Err(Status::internal(
97+
"Error decoding handle to uuid for {request_id}",
98+
))
99+
}
100+
}
101+
}
102+
103+
async fn record_request(
104+
&self,
105+
start: Timestamp,
106+
request_id: Option<String>,
107+
response_err: Option<&Status>,
108+
path: String,
109+
latency_metric: &'static str,
110+
) {
111+
let duration = Timestamp::now() - start;
112+
let grpc_code = match &response_err {
113+
None => Code::Ok,
114+
Some(status) => status.code(),
115+
};
116+
let ctx = self.execution.session_ctx();
117+
let req = ObservabilityRequestDetails {
118+
request_id,
119+
path,
120+
sql: None,
121+
rows: None,
122+
start_ms: start.as_millisecond(),
123+
duration_ms: duration.get_milliseconds(),
124+
status: grpc_code as u16,
125+
};
126+
if let Err(e) = self
127+
.execution
128+
.observability()
129+
.try_record_request(ctx, req)
130+
.await
131+
{
132+
error!("Error recording request: {}", e.to_string())
133+
}
134+
135+
histogram!(latency_metric).record(duration.get_milliseconds() as f64);
136+
}
137+
63138
async fn get_flight_info_statement_handler(
64139
&self,
65140
query: String,
@@ -120,47 +195,22 @@ impl FlightSqlServiceImpl {
120195
}
121196
}
122197

198+
async fn do_get_statement_handler(
199+
&self,
200+
request_id: String,
201+
ticket: TicketStatementQuery,
202+
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
203+
debug!("do_get_statement ticket: {:?}", ticket);
204+
self.do_get_common_handler(request_id).await
205+
}
206+
123207
async fn do_get_fallback_handler(
124208
&self,
125209
request_id: String,
126210
message: Any,
127211
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
128212
debug!("do_get_fallback message: {:?}", message);
129-
130-
match Uuid::from_str(&request_id) {
131-
Ok(id) => {
132-
info!("getting plan for id: {:?}", id);
133-
// Limit the scope of the lock
134-
let maybe_plan = {
135-
let guard = self
136-
.requests
137-
.lock()
138-
.map_err(|_| Status::internal("Failed to acquire lock on requests"))?;
139-
guard.get(&id).cloned()
140-
};
141-
if let Some(plan) = maybe_plan {
142-
let stream = self
143-
.execution
144-
.execute_logical_plan(plan)
145-
.await
146-
.map_err(|e| Status::internal(e.to_string()))?;
147-
let builder = FlightDataEncoderBuilder::new();
148-
let flight_data_stream = builder
149-
.build(stream.map_err(|e| FlightError::ExternalError(Box::new(e))))
150-
.map_err(|e| Status::internal(e.to_string()))
151-
.boxed();
152-
Ok(Response::new(flight_data_stream))
153-
} else {
154-
Err(Status::internal("Plan not found for id"))
155-
}
156-
}
157-
Err(e) => {
158-
error!("error decoding handle to uuid for {request_id}: {:?}", e);
159-
Err(Status::internal(
160-
"Error decoding handle to uuid for {request_id}",
161-
))
162-
}
163-
}
213+
self.do_get_common_handler(request_id).await
164214
}
165215
}
166216

@@ -180,41 +230,43 @@ impl FlightSqlService for FlightSqlServiceImpl {
180230
let res = self
181231
.get_flight_info_statement_handler(query.clone(), request_id, request)
182232
.await;
183-
let duration = Timestamp::now() - start;
184-
185-
let grpc_code = match &res {
186-
Ok(_) => Code::Ok,
187-
Err(status) => status.code(),
188-
};
189233

190-
let ctx = self.execution.session_ctx();
191-
let req = ObservabilityRequestDetails {
192-
request_id: Some(request_id.to_string()),
193-
path: "GetFlightInfo".to_string(),
194-
sql: Some(query),
195-
rows: None,
196-
start_ms: start.as_millisecond(),
197-
duration_ms: duration.get_milliseconds(),
198-
status: grpc_code as u16,
199-
};
200-
if let Err(e) = self
201-
.execution
202-
.observability()
203-
.try_record_request(ctx, req)
204-
.await
205-
{
206-
error!("Error recording request: {}", e.to_string())
207-
}
208-
histogram!("get_flight_info_latency_ms").record(duration.get_milliseconds() as f64);
234+
// TODO: Move recording to after response is sent to not impact response latency
235+
self.record_request(
236+
start,
237+
Some(request_id.to_string()),
238+
res.as_ref().err(),
239+
"/get_flight_info_statement".to_string(),
240+
"get_flight_info_statement_latency_ms",
241+
)
242+
.await;
209243
res
210244
}
211245

212246
async fn do_get_statement(
213247
&self,
214-
_ticket: TicketStatementQuery,
215-
_request: Request<Ticket>,
248+
ticket: TicketStatementQuery,
249+
request: Request<Ticket>,
216250
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
217-
Err(Status::unimplemented("Not implemented"))
251+
counter!("requests", "endpoint" => "do_get_statement").increment(1);
252+
let start = Timestamp::now();
253+
let request_id =
254+
try_request_id_from_request(request).map_err(|e| Status::internal(e.to_string()))?;
255+
debug!("do_get_statement for request_id: {}", &request_id);
256+
let res = self
257+
.do_get_statement_handler(request_id.clone(), ticket)
258+
.await;
259+
260+
// TODO: Move recording to after response is sent to not impact response latency
261+
self.record_request(
262+
start,
263+
Some(request_id),
264+
res.as_ref().err(),
265+
"/do_get_statement".to_string(),
266+
"do_get_statement_latency_ms",
267+
)
268+
.await;
269+
res
218270
}
219271

220272
async fn do_get_fallback(
@@ -231,31 +283,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
231283
.do_get_fallback_handler(request_id.clone(), message)
232284
.await;
233285

234-
let duration = Timestamp::now() - start;
235-
let grpc_code = match &res {
236-
Ok(_) => Code::Ok,
237-
Err(status) => status.code(),
238-
};
239-
let ctx = self.execution.session_ctx();
240-
let req = ObservabilityRequestDetails {
241-
request_id: Some(request_id),
242-
path: "DoGetFallback".to_string(),
243-
sql: None,
244-
rows: None,
245-
start_ms: start.as_millisecond(),
246-
duration_ms: duration.get_milliseconds(),
247-
status: grpc_code as u16,
248-
};
249-
if let Err(e) = self
250-
.execution
251-
.observability()
252-
.try_record_request(ctx, req)
253-
.await
254-
{
255-
error!("Error recording request: {}", e.to_string())
256-
}
257-
258-
histogram!("do_get_fallback_latency_ms").record(duration.get_milliseconds() as f64);
286+
// TODO: Move recording to after response is sent to not impact response latency
287+
self.record_request(
288+
start,
289+
Some(request_id),
290+
res.as_ref().err(),
291+
"/do_get_fallback".to_string(),
292+
"do_get_fallback_latency_ms",
293+
)
294+
.await;
259295
res
260296
}
261297

0 commit comments

Comments
 (0)