Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 59 additions & 50 deletions numaflow/src/batchmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,16 @@ where

self.perform_handshake(&mut map_stream, &resp_tx).await?;

let (error_tx, error_rx) = channel::<Error>(1);

let grpc_resp_tx = resp_tx.clone();
let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
Self::process_map_stream(map_handle, map_stream, grpc_resp_tx).await
Self::process_map_stream(map_handle, map_stream, grpc_resp_tx, error_tx, cln_token)
.await
});

tokio::spawn(Self::handle_map_errors(
handle,
resp_tx,
shutdown_tx,
cln_token,
));
// TODO: structured concurrency is lost here. The JoinHandle of this spawned task is dropped,
// so if `manage_grpc_stream` panics, we will never know.
tokio::spawn(manage_grpc_stream(handle, resp_tx, error_rx, shutdown_tx));

Ok(Response::new(ReceiverStream::new(resp_rx)))
}
Expand All @@ -282,15 +281,29 @@ where
map_handle: Arc<T>,
mut map_stream: Streaming<MapRequest>,
grpc_resp_tx: mpsc::Sender<Result<MapResponse, Status>>,
error_tx: mpsc::Sender<Error>,
cln_token: CancellationToken,
) -> Result<(), Error> {
// Loop until the global stream has been shut down.
let mut global_stream_ended = false;
while !global_stream_ended {
loop {
// for every batch, we need to read from the stream. The end-of-batch is
// encoded in the request.
global_stream_ended =
Self::process_map_batch(map_handle.clone(), &mut map_stream, grpc_resp_tx.clone())
.await?;
tokio::select! {
stream_ended = Self::process_map_batch(
map_handle.clone(),
&mut map_stream,
grpc_resp_tx.clone(),
error_tx.clone(),
) => {
if stream_ended? {
break;
}
}
_ = cln_token.cancelled() => {
info!("Cancellation token is cancelled, shutting down");
break;
}
}
}
Ok(())
}
Expand All @@ -303,6 +316,7 @@ where
batch_map_handle: Arc<T>,
map_stream: &mut Streaming<MapRequest>,
grpc_resp_tx: mpsc::Sender<Result<MapResponse, Status>>,
error_tx: mpsc::Sender<Error>,
) -> Result<bool, Error> {
let (tx, rx) = channel::<Datum>(CHANNEL_SIZE);
let resp_tx = grpc_resp_tx.clone();
Expand All @@ -324,11 +338,11 @@ where
if let Some(panic_info) = get_panic_info() {
// This is a panic - send detailed panic information
let status = build_panic_status(&panic_info);
let _ = resp_tx.send(Err(status)).await;
let _ = error_tx.send(Error::GrpcStatus(status)).await;
} else {
// This is a non-panic error
let _ = resp_tx
.send(Err(Status::internal(format!(
let _ = error_tx
.send(Error::BatchMapError(ErrorKind::InternalError(format!(
"Batch-map task execution failed: {e:?}"
))))
.await;
Expand Down Expand Up @@ -423,41 +437,6 @@ where
Ok(global_stream_ended)
}

async fn handle_map_errors(
handle: JoinHandle<Result<(), Error>>,
resp_tx: mpsc::Sender<Result<MapResponse, Status>>,
shutdown_tx: mpsc::Sender<()>,
cln_token: CancellationToken,
) {
tokio::select! {
resp = handle => {
match resp {
Ok(Ok(_)) => {},
Ok(Err(e)) => {
resp_tx
.send(Err(e.into_status()))
.await
.expect("Sending error to response channel");
shutdown_tx.send(()).await.expect("Sending shutdown signal");
}
Err(e) => {
resp_tx
.send(Err(Status::internal(format!("Map handler aborted: {}", e))))
.await
.expect("Sending error to response channel");
shutdown_tx.send(()).await.expect("Sending shutdown signal");
}
}
},
_ = cln_token.cancelled() => {
resp_tx
.send(Err(Status::cancelled("Map handler cancelled")))
.await
.expect("Sending error to response channel");
}
}
}

async fn perform_handshake(
&self,
map_stream: &mut Streaming<MapRequest>,
Expand Down Expand Up @@ -488,6 +467,36 @@ where
}
}

async fn manage_grpc_stream(
request_handler: JoinHandle<Result<(), Error>>,
stream_response_tx: mpsc::Sender<Result<MapResponse, Status>>,
mut error_rx: mpsc::Receiver<Error>,
server_shutdown_tx: mpsc::Sender<()>,
) {
let err = match error_rx.recv().await {
Some(err) => err,
None => match request_handler.await {
Ok(Ok(_)) => return, // normal exit
Ok(Err(e)) => e,
Err(e) => Error::BatchMapError(ErrorKind::InternalError(format!(
"BatchMap request handler aborted: {e:?}"
))),
},
};

error!("Shutting down gRPC channel: {err:?}");
// send error response to the numaflow client
stream_response_tx
.send(Err(err.into_status()))
.await
.expect("Sending error message to gRPC response channel");
// send shutdown signal to the server
server_shutdown_tx
.send(())
.await
.expect("Writing to shutdown channel");
}

/// gRPC server to start a batch map service
#[derive(Debug)]
pub struct Server<T> {
Expand Down
Loading