
Commit 3de47ad

highker authored and facebook-github-bot committed
sync flush of proc mesh
Summary: Provide a sync flush so it is guaranteed that all flushed logs on the remote procs will be streamed back and flushed to the client's stdout/stderr.

Differential Revision: D80051803
1 parent 242c726 commit 3de47ad
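The diff below wires this up as a version-counted barrier: the caller sends StartSyncFlush { expected_procs, reply, version } to the LogClientActor, broadcasts the returned version to every LogForwardActor via ForceSyncFlush { version }, and each forwarder pushes Flush { sync_version: Some(version) } through its own log channel so the ack only arrives after the queued logs have drained; the client fires the reply port once the last proc acks. A minimal standalone sketch of that bookkeeping, using a tokio oneshot in place of hyperactor's OncePortRef (the FlushBarrier type and its method names are illustrative only, not part of this commit):

use tokio::sync::oneshot;

// Standalone model of the flush barrier this commit adds to LogClientActor.
// The real logic lives in start_sync_flush / flush in the diff below; the
// real reply channel is a hyperactor OncePortRef rather than a tokio oneshot.
struct FlushBarrier {
    current_version: u64,
    unflushed_procs: usize,
    reply: Option<oneshot::Sender<()>>,
}

impl FlushBarrier {
    // Client side of StartSyncFlush: bump the version, remember how many procs
    // must ack, and hand the version back so the caller can broadcast
    // ForceSyncFlush { version } to every forwarder.
    fn start_sync_flush(&mut self, expected_procs: usize, reply: oneshot::Sender<()>) -> u64 {
        self.current_version += 1;
        self.unflushed_procs = expected_procs;
        self.reply = Some(reply);
        self.current_version
    }

    // Client side of Flush { sync_version: Some(version) } coming back from a
    // proc: drop stale versions, count down, and release the caller once every
    // proc has drained its log channel.
    fn ack_flush(&mut self, version: u64) {
        if version != self.current_version || self.unflushed_procs == 0 {
            return; // stale or unexpected ack; the real actor logs or bails here
        }
        self.unflushed_procs -= 1;
        if self.unflushed_procs == 0 {
            if let Some(reply) = self.reply.take() {
                let _ = reply.send(()); // all procs flushed: unblock the waiting client
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let mut barrier = FlushBarrier { current_version: 0, unflushed_procs: 0, reply: None };
    let (tx, rx) = oneshot::channel();
    let version = barrier.start_sync_flush(2, tx);
    barrier.ack_flush(version); // first proc acks
    barrier.ack_flush(version); // second proc acks, reply fires
    rx.await.expect("barrier released");
    println!("all procs flushed at version {version}");
}

Versioning lets a new sync flush supersede one that never completed, which is why mismatched acks are logged and dropped rather than treated as fatal.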

File tree

5 files changed: +289 -44 lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 189 additions & 35 deletions
@@ -11,6 +11,7 @@ use std::fmt;
 use std::path::Path;
 use std::path::PathBuf;
 use std::pin::Pin;
+use std::sync::Arc;
 use std::task::Context as TaskContext;
 use std::task::Poll;
 use std::time::Duration;
@@ -22,12 +23,15 @@ use chrono::DateTime;
 use chrono::Local;
 use hyperactor::Actor;
 use hyperactor::ActorRef;
+use hyperactor::Bind;
 use hyperactor::Context;
 use hyperactor::HandleClient;
 use hyperactor::Handler;
 use hyperactor::Instance;
 use hyperactor::Named;
+use hyperactor::OncePortRef;
 use hyperactor::RefClient;
+use hyperactor::Unbind;
 use hyperactor::channel;
 use hyperactor::channel::ChannelAddr;
 use hyperactor::channel::ChannelRx;
@@ -39,14 +43,12 @@ use hyperactor::channel::TxStatus;
 use hyperactor::clock::Clock;
 use hyperactor::clock::RealClock;
 use hyperactor::data::Serialized;
-use hyperactor::message::Bind;
-use hyperactor::message::Bindings;
-use hyperactor::message::Unbind;
 use hyperactor_telemetry::env;
 use hyperactor_telemetry::log_file_path;
 use serde::Deserialize;
 use serde::Serialize;
 use tokio::io;
+use tokio::sync::Mutex;
 use tokio::sync::watch::Receiver;
 
 use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
@@ -260,7 +262,11 @@ pub enum LogMessage {
     },
 
     /// Flush the log
-    Flush {},
+    Flush {
+        /// Indicate if the current flush is synced or non-synced.
+        /// If synced, a version number is available. Otherwise, none.
+        sync_version: Option<u64>,
+    },
 }
 
 /// Messages that can be sent to the LogClient locally.
@@ -279,6 +285,16 @@ pub enum LogClientMessage {
         /// The time window in seconds to aggregate logs. If None, aggregation is disabled.
         aggregate_window_sec: Option<u64>,
     },
+
+    /// Synchronously flush all the logs from all the procs. This is for client to call.
+    StartSyncFlush {
+        /// Expect these many procs to ack the flush message.
+        expected_procs: usize,
+        /// Return once we have received the acks from all the procs
+        reply: OncePortRef<()>,
+        /// Return to the caller the current flush version
+        version: OncePortRef<u64>,
+    },
 }
 
 /// Trait for sending logs
@@ -352,7 +368,7 @@ impl LogSender for LocalLogSender {
         // send will make sure message is delivered
         if TxStatus::Active == *self.status.borrow() {
            // Do not use tx.send, it will block the allocator as the child process state is unknown.
-            self.tx.post(LogMessage::Flush {});
+            self.tx.post(LogMessage::Flush { sync_version: None });
         } else {
             tracing::debug!(
                 "log sender {} is not active, skip sending flush message",
@@ -558,26 +574,19 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
     Named,
     Handler,
     HandleClient,
-    RefClient
+    RefClient,
+    Bind,
+    Unbind
 )]
 pub enum LogForwardMessage {
     /// Receive the log from the parent process and forward ti to the client.
     Forward {},
 
     /// If to stream the log back to the client.
     SetMode { stream_to_client: bool },
-}
 
-impl Bind for LogForwardMessage {
-    fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
-}
-
-impl Unbind for LogForwardMessage {
-    fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
+    /// Flush the log with a version number.
+    ForceSyncFlush { version: u64 },
 }
 
 /// A log forwarder that receives the log from its parent process and forward it back to the client
@@ -588,6 +597,8 @@ impl Unbind for LogForwardMessage {
 )]
 pub struct LogForwardActor {
     rx: ChannelRx<LogMessage>,
+    flush_tx: Arc<Mutex<ChannelTx<LogMessage>>>,
+    next_flush_deadline: SystemTime,
     logging_client_ref: ActorRef<LogClientActor>,
     stream_to_client: bool,
 }
@@ -630,15 +641,29 @@ impl Actor for LogForwardActor {
                     .1
             }
         };
+
+        // Dial the same channel to send flush message to drain the log queue.
+        let flush_tx = Arc::new(Mutex::new(channel::dial::<LogMessage>(log_channel)?));
+        let now = RealClock.system_time_now();
+
         Ok(Self {
             rx,
+            flush_tx,
+            next_flush_deadline: now,
             logging_client_ref,
             stream_to_client: true,
         })
     }
 
     async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
         this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
+
+        // Make sure we start the flush loop periodically so the log channel will not deadlock.
+        self.flush_tx
+            .lock()
+            .await
+            .send(LogMessage::Flush { sync_version: None })
+            .await?;
         Ok(())
     }
 }
@@ -647,17 +672,48 @@ impl Actor for LogForwardActor {
 #[hyperactor::forward(LogForwardMessage)]
 impl LogForwardMessageHandler for LogForwardActor {
     async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
-        if let Ok(LogMessage::Log {
-            hostname,
-            pid,
-            output_target,
-            payload,
-        }) = self.rx.recv().await
-        {
-            if self.stream_to_client {
-                self.logging_client_ref
-                    .log(ctx, hostname, pid, output_target, payload)
-                    .await?;
+        match self.rx.recv().await {
+            Ok(LogMessage::Flush { sync_version }) => {
+                let now = RealClock.system_time_now();
+                match sync_version {
+                    None => {
+                        // Schedule another flush to keep the log channel from deadlocking.
+                        let delay = Duration::from_secs(1);
+                        if now >= self.next_flush_deadline {
+                            self.next_flush_deadline = now + delay;
+                            let flush_tx = self.flush_tx.clone();
+                            tokio::spawn(async move {
+                                RealClock.sleep(delay).await;
+                                if let Err(e) = flush_tx
+                                    .lock()
+                                    .await
+                                    .send(LogMessage::Flush { sync_version: None })
+                                    .await
+                                {
+                                    tracing::error!("failed to send flush message: {}", e);
+                                }
+                            });
+                        }
+                    }
+                    version => {
+                        self.logging_client_ref.flush(ctx, version).await?;
+                    }
+                }
+            }
+            Ok(LogMessage::Log {
+                hostname,
+                pid,
+                output_target,
+                payload,
+            }) => {
+                if self.stream_to_client {
+                    self.logging_client_ref
+                        .log(ctx, hostname, pid, output_target, payload)
+                        .await?;
+                }
+            }
+            Err(e) => {
+                return Err(e.into());
             }
         }
 
@@ -675,6 +731,21 @@ impl LogForwardMessageHandler for LogForwardActor {
         self.stream_to_client = stream_to_client;
         Ok(())
     }
+
+    async fn force_sync_flush(
+        &mut self,
+        _cx: &Context<Self>,
+        version: u64,
+    ) -> Result<(), anyhow::Error> {
+        self.flush_tx
+            .lock()
+            .await
+            .send(LogMessage::Flush {
+                sync_version: Some(version),
+            })
+            .await
+            .map_err(anyhow::Error::from)
+    }
 }
 
 /// Deserialize a serialized message and split it into UTF-8 lines
@@ -707,6 +778,11 @@ pub struct LogClientActor {
     aggregators: HashMap<OutputTarget, Aggregator>,
     last_flush_time: SystemTime,
     next_flush_deadline: Option<SystemTime>,
+
+    // For flush sync barrier
+    current_flush_version: u64,
+    current_flush_port: Option<OncePortRef<()>>,
+    current_unflushed_procs: usize,
 }
 
 impl LogClientActor {
@@ -736,6 +812,12 @@ impl LogClientActor {
             OutputTarget::Stderr => eprintln!("{}", message),
         }
     }
+
+    fn flush_internal(&mut self) {
+        self.print_aggregators();
+        self.last_flush_time = RealClock.system_time_now();
+        self.next_flush_deadline = None;
+    }
 }
 
 #[async_trait]
@@ -754,6 +836,9 @@ impl Actor for LogClientActor {
             aggregators,
             last_flush_time: RealClock.system_time_now(),
             next_flush_deadline: None,
+            current_flush_version: 0,
+            current_flush_port: None,
+            current_unflushed_procs: 0,
         })
     }
 }
@@ -805,20 +890,26 @@ impl LogMessageHandler for LogClientActor {
             let new_deadline = self.last_flush_time + Duration::from_secs(window);
             let now = RealClock.system_time_now();
             if new_deadline <= now {
-                self.flush(cx).await?;
+                self.flush_internal();
             } else {
                 let delay = new_deadline.duration_since(now)?;
                 match self.next_flush_deadline {
                     None => {
                         self.next_flush_deadline = Some(new_deadline);
-                        cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                        cx.self_message_with_delay(
+                            LogMessage::Flush { sync_version: None },
+                            delay,
+                        )?;
                     }
                     Some(deadline) => {
                         // Some early log lines have alrady triggered the flush.
                         if new_deadline < deadline {
                             // This can happen if the user has adjusted the aggregation window.
                             self.next_flush_deadline = Some(new_deadline);
-                            cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                            cx.self_message_with_delay(
+                                LogMessage::Flush { sync_version: None },
+                                delay,
+                            )?;
                         }
                     }
                 }
@@ -829,10 +920,45 @@ impl LogMessageHandler for LogClientActor {
         Ok(())
     }
 
-    async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
-        self.print_aggregators();
-        self.last_flush_time = RealClock.system_time_now();
-        self.next_flush_deadline = None;
+    async fn flush(
+        &mut self,
+        cx: &Context<Self>,
+        sync_version: Option<u64>,
+    ) -> Result<(), anyhow::Error> {
+        match sync_version {
+            None => {
+                self.flush_internal();
+            }
+            Some(version) => {
+                if version != self.current_flush_version {
+                    tracing::error!(
+                        "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully",
+                        version,
+                        self.current_flush_version
+                    );
+                    return Ok(());
+                }
+
+                if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() {
+                    // This is a serious issue; it's better to error out.
+                    anyhow::bail!("found no ongoing flush request");
+                }
+                self.current_unflushed_procs -= 1;
+
+                tracing::debug!(
+                    "ack sync flush: version {}; remaining procs: {}",
+                    self.current_flush_version,
+                    self.current_unflushed_procs
+                );
+
+                if self.current_unflushed_procs == 0 {
+                    self.flush_internal();
+                    let reply = self.current_flush_port.take().unwrap();
+                    self.current_flush_port = None;
+                    reply.send(cx, ()).map_err(anyhow::Error::from)?;
+                }
+            }
+        }
 
         Ok(())
     }
@@ -853,6 +979,34 @@ impl LogClientMessageHandler for LogClientActor {
         self.aggregate_window_sec = aggregate_window_sec;
         Ok(())
     }
+
+    async fn start_sync_flush(
+        &mut self,
+        cx: &Context<Self>,
+        expected_procs_flushed: usize,
+        reply: OncePortRef<()>,
+        version: OncePortRef<u64>,
+    ) -> Result<(), anyhow::Error> {
+        if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() {
+            tracing::warn!(
+                "found unfinished ongoing flush: version {}; {} unflushed procs",
+                self.current_flush_version,
+                self.current_unflushed_procs,
+            );
+        }
+
+        self.current_flush_version += 1;
+        tracing::debug!(
+            "start sync flush with version {}",
+            self.current_flush_version
+        );
+        self.current_flush_port = Some(reply.clone());
+        self.current_unflushed_procs = expected_procs_flushed;
+        version
+            .send(cx, self.current_flush_version)
+            .map_err(anyhow::Error::from)?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]

0 commit comments
