Skip to content

Commit f474be2

Browse files
authored
Fix hanging caused by tqdm stderr not being printed (#352)
1 parent ae7aa14 commit f474be2

File tree

2 files changed

+35
-22
lines changed

2 files changed

+35
-22
lines changed

launcher/src/main.rs

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ fn shard_manager(
497497
// Safetensors load fast
498498
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
499499

500+
// Disable progress bars to prevent hanging in containers
501+
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
502+
500503
// Enable hf transfer for insane download speeds
501504
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
502505
envs.push((
@@ -564,13 +567,20 @@ fn shard_manager(
564567
}
565568
};
566569

567-
// Redirect STDOUT to the console
568-
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
569-
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
570+
let shard_stdout = BufReader::new(p.stdout.take().unwrap());
571+
572+
thread::spawn(move || {
573+
log_lines(shard_stdout.lines());
574+
});
575+
576+
let shard_stderr = BufReader::new(p.stderr.take().unwrap());
570577

571-
//stdout tracing thread
578+
// We read stderr in another thread as it seems that lines() can block in some cases
579+
let (err_sender, err_receiver) = mpsc::channel();
572580
thread::spawn(move || {
573-
log_lines(shard_stdout_reader.lines());
581+
for line in shard_stderr.lines().flatten() {
582+
err_sender.send(line).unwrap_or(());
583+
}
574584
});
575585

576586
let mut ready = false;
@@ -579,13 +589,6 @@ fn shard_manager(
579589
loop {
580590
// Process exited
581591
if let Some(exit_status) = p.try_wait().unwrap() {
582-
// We read stderr in another thread as it seems that lines() can block in some cases
583-
let (err_sender, err_receiver) = mpsc::channel();
584-
thread::spawn(move || {
585-
for line in shard_stderr_reader.lines().flatten() {
586-
err_sender.send(line).unwrap_or(());
587-
}
588-
});
589592
let mut err = String::new();
590593
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
591594
err = err + "\n" + &line;
@@ -796,6 +799,9 @@ fn download_convert_model(
796799
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
797800
};
798801

802+
// Disable progress bars to prevent hanging in containers
803+
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
804+
799805
// Enable hf transfer for insane download speeds
800806
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
801807
envs.push((
@@ -840,12 +846,20 @@ fn download_convert_model(
840846
}
841847
};
842848

843-
// Redirect STDOUT to the console
844-
let download_stdout = download_process.stdout.take().unwrap();
845-
let stdout = BufReader::new(download_stdout);
849+
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
846850

847851
thread::spawn(move || {
848-
log_lines(stdout.lines());
852+
log_lines(download_stdout.lines());
853+
});
854+
855+
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
856+
857+
// We read stderr in another thread as it seems that lines() can block in some cases
858+
let (err_sender, err_receiver) = mpsc::channel();
859+
thread::spawn(move || {
860+
for line in download_stderr.lines().flatten() {
861+
err_sender.send(line).unwrap_or(());
862+
}
849863
});
850864

851865
loop {
@@ -856,12 +870,9 @@ fn download_convert_model(
856870
}
857871

858872
let mut err = String::new();
859-
download_process
860-
.stderr
861-
.take()
862-
.unwrap()
863-
.read_to_string(&mut err)
864-
.unwrap();
873+
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
874+
err = err + "\n" + &line;
875+
}
865876
if let Some(signal) = status.signal() {
866877
tracing::error!(
867878
"Download process was signaled to shutdown with signal {signal}: {err}"

server/lorax_server/models/flash_causal_lm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,10 +784,12 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
784784
)
785785

786786
with warmup_mode():
787+
logger.info("Warming up to max_total_tokens: {}", max_new_tokens)
787788
with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
788789
for _ in range(max_new_tokens):
789790
_, batch = self.generate_token(batch, is_warmup=True)
790791
pbar.update(1)
792+
logger.info("Finished generating warmup tokens")
791793
except RuntimeError as e:
792794
if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
793795
raise RuntimeError(

0 commit comments

Comments (0)