Skip to content

Commit cdce47c

Browse files
committed
[TRTLLM-8922][chore] Remove hang detector from py_cache_transceiver_mp test
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 747915e commit cdce47c

File tree

1 file changed

+7
-45
lines changed

1 file changed

+7
-45
lines changed

tests/unittest/disaggregated/test_py_cache_transceiver_mp.py

Lines changed: 7 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
import tensorrt_llm.bindings
2020
import tensorrt_llm.bindings.executor as trtllm
2121
from tensorrt_llm import DisaggregatedParams, Mapping, SamplingParams
22-
from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector
2322
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestType
2423
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
2524
from tensorrt_llm.bindings import DataType, LlmRequestState
@@ -243,13 +242,6 @@ def signal_handler(signum, frame):
243242
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
244243
tensorrt_llm.logger.set_level("info")
245244

246-
def on_hang_detected():
247-
print(f"[Rank {rank}] Hang detected! Forcing exit.", flush=True)
248-
os._exit(1)
249-
250-
hang_detector = HangDetector(timeout=60, on_detected=on_hang_detected)
251-
hang_detector.start()
252-
253245
ctx_instance_num = ctx_tp * ctx_pp
254246
gen_instance_num = gen_tp * gen_pp
255247

@@ -485,7 +477,6 @@ def on_hang_detected():
485477

486478
# Synchronize all processes
487479
dist.barrier()
488-
hang_detector.checkpoint()
489480

490481
# ===== Batch process multiple requests (like C++ cacheTransceiverTest) =====
491482
# Reference: C++ test uses lenList = {30, 10, 60, 80}
@@ -709,33 +700,29 @@ def gather_and_verify_request(
709700
f"handling {len(my_requests)}, {'CTX' if is_ctx else 'GEN'} mode, tp_rank={tp_rank}",
710701
flush=True,
711702
)
712-
hang_detector.checkpoint()
713703

714704
# ===== Phase 2: Transfer =====
715705
if ctx_gen_workflow == "gen_first1":
716-
_run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests, hang_detector.checkpoint)
706+
_run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests)
717707
elif ctx_gen_workflow == "gen_first2":
718-
_run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests, hang_detector.checkpoint)
708+
_run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests)
719709
else:
720710
_run_ctx_first_transfer(
721711
rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp
722712
)
723-
hang_detector.checkpoint()
724713

725714
# ===== Phase 3: Wait for remaining transfers to complete =====
726715
# Synchronize before checking completion
727-
hang_detector.checkpoint()
716+
728717
dist.barrier()
729-
hang_detector.checkpoint()
730718

731719
if is_ctx and my_requests:
732720
transceiver.check_context_transfer_status(None)
733721
print(f"[Rank {rank}] CTX: All transfers completed ({mode_str})", flush=True)
734-
hang_detector.checkpoint()
722+
735723
elif not is_ctx and my_requests:
736724
transceiver.check_gen_transfer_status(None)
737725
print(f"[Rank {rank}] GEN: All transfers completed ({mode_str})", flush=True)
738-
hang_detector.checkpoint()
739726

740727
if is_gen_first:
741728
# verify the aux data is unpacked correctly
@@ -755,9 +742,8 @@ def gather_and_verify_request(
755742
)
756743

757744
# Synchronize before verification
758-
hang_detector.checkpoint()
745+
759746
dist.barrier()
760-
hang_detector.checkpoint()
761747

762748
# ===== Phase 4: Batch verify all requests =====
763749
# All ranks must participate in gather (collective op), so iterate all_requests.
@@ -814,9 +800,7 @@ def gather_and_verify_request(
814800
dist.broadcast(pass_tensor, src=0)
815801
assert pass_tensor.item() == 1, "Some requests failed verification!"
816802

817-
hang_detector.checkpoint()
818803
dist.barrier()
819-
hang_detector.checkpoint()
820804

821805
# ===== Phase 5: Cleanup requests =====
822806
# All ranks added all requests, so all need to remove them
@@ -827,8 +811,6 @@ def gather_and_verify_request(
827811
if rank == 0:
828812
print(f"[Rank {rank}] Cleanup completed ({mode_str})")
829813

830-
hang_detector.stop()
831-
832814
# Cleanup
833815
dist.destroy_process_group()
834816

@@ -909,13 +891,8 @@ def _wait_ctx_request_ready(transceiver, my_requests):
909891
return all_ready
910892

911893

912-
def _run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests, checkpoint_fn=None):
894+
def _run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests):
913895
"""Generation-first transfer: ctx prepares first, then gen receives and ctx sends."""
914-
915-
def _checkpoint():
916-
if checkpoint_fn:
917-
checkpoint_fn()
918-
919896
# Step 1: Context side calls prepare_context_requests, no kvcache request is sent, thus no request
920897
# can reach CONTEXT_INIT state.
921898
if is_ctx:
@@ -928,9 +905,7 @@ def _checkpoint():
928905
assert req.state == LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER
929906
print(f"[Rank {rank}] CTX: All requests are waiting for being scheduled", flush=True)
930907

931-
_checkpoint()
932908
dist.barrier()
933-
_checkpoint()
934909

935910
# Step 2: Generation side submits receive requests
936911
if not is_ctx:
@@ -944,15 +919,12 @@ def _checkpoint():
944919
f"[Rank {rank}] GEN: Submitted {len(my_requests)} gen-first receive requests",
945920
flush=True,
946921
)
947-
_checkpoint()
948922
dist.barrier()
949-
_checkpoint()
950923

951924
if is_ctx:
952925
# Poll until all requests reach CONTEXT_INIT (peer info arrived)
953926
transceiver.prepare_context_requests(ctx_my_requests)
954927
_wait_ctx_request_ready(transceiver, ctx_my_requests)
955-
_checkpoint()
956928

957929
for req_idx, request in my_requests:
958930
print(
@@ -963,13 +935,8 @@ def _checkpoint():
963935
print(f"[Rank {rank}] CTX: Submitted {len(my_requests)} send requests", flush=True)
964936

965937

966-
def _run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests, checkpoint_fn=None):
938+
def _run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests):
967939
"""Generation-first transfer: gen receives first, then ctx prepares and sends."""
968-
969-
def _checkpoint():
970-
if checkpoint_fn:
971-
checkpoint_fn()
972-
973940
# Step 1: Generation side submits receive requests, now context side doesn't know the requests
974941
# but gets kvcache requests first
975942
if not is_ctx:
@@ -983,9 +950,7 @@ def _checkpoint():
983950
f"[Rank {rank}] GEN: Submitted {len(my_requests)} gen-first receive requests",
984951
flush=True,
985952
)
986-
_checkpoint()
987953
dist.barrier()
988-
_checkpoint()
989954
time.sleep(3) # wait for the receive requests to be submitted
990955
# Step 2: Context side calls prepare_context_requests, now context side knows the requests
991956
# all requests can reach CONTEXT_INIT state directly.
@@ -995,11 +960,8 @@ def _checkpoint():
995960
transceiver.prepare_context_requests(ctx_my_requests)
996961
print(f"[Rank {rank}] CTX: Called prepare_context_requests", flush=True)
997962
_wait_ctx_request_ready(transceiver, ctx_my_requests)
998-
_checkpoint()
999963

1000-
_checkpoint()
1001964
dist.barrier()
1002-
_checkpoint()
1003965
# Step 3: Context side sends the data
1004966
if is_ctx:
1005967
for req_idx, request in my_requests:

0 commit comments

Comments (0)