1919import tensorrt_llm .bindings
2020import tensorrt_llm .bindings .executor as trtllm
2121from tensorrt_llm import DisaggregatedParams , Mapping , SamplingParams
22- from tensorrt_llm ._torch .pyexecutor .hang_detector import HangDetector
2322from tensorrt_llm ._torch .pyexecutor .llm_request import LlmRequest , LlmRequestType
2423from tensorrt_llm ._torch .pyexecutor .resource_manager import KVCacheManager
2524from tensorrt_llm .bindings import DataType , LlmRequestState
@@ -243,13 +242,6 @@ def signal_handler(signum, frame):
243242 dist .init_process_group (backend = "gloo" , rank = rank , world_size = world_size )
244243 tensorrt_llm .logger .set_level ("info" )
245244
246- def on_hang_detected ():
247- print (f"[Rank { rank } ] Hang detected! Forcing exit." , flush = True )
248- os ._exit (1 )
249-
250- hang_detector = HangDetector (timeout = 60 , on_detected = on_hang_detected )
251- hang_detector .start ()
252-
253245 ctx_instance_num = ctx_tp * ctx_pp
254246 gen_instance_num = gen_tp * gen_pp
255247
@@ -485,7 +477,6 @@ def on_hang_detected():
485477
486478 # Synchronize all processes
487479 dist .barrier ()
488- hang_detector .checkpoint ()
489480
490481 # ===== Batch process multiple requests (like C++ cacheTransceiverTest) =====
491482 # Reference: C++ test uses lenList = {30, 10, 60, 80}
@@ -709,33 +700,29 @@ def gather_and_verify_request(
709700 f"handling { len (my_requests )} , { 'CTX' if is_ctx else 'GEN' } mode, tp_rank={ tp_rank } " ,
710701 flush = True ,
711702 )
712- hang_detector .checkpoint ()
713703
714704 # ===== Phase 2: Transfer =====
715705 if ctx_gen_workflow == "gen_first1" :
716- _run_gen_first1_transfer (rank , is_ctx , transceiver , my_requests , hang_detector . checkpoint )
706+ _run_gen_first1_transfer (rank , is_ctx , transceiver , my_requests )
717707 elif ctx_gen_workflow == "gen_first2" :
718- _run_gen_first2_transfer (rank , is_ctx , transceiver , my_requests , hang_detector . checkpoint )
708+ _run_gen_first2_transfer (rank , is_ctx , transceiver , my_requests )
719709 else :
720710 _run_ctx_first_transfer (
721711 rank , is_ctx , transceiver , my_requests , ctx_enable_dp , gen_enable_dp
722712 )
723- hang_detector .checkpoint ()
724713
725714 # ===== Phase 3: Wait for remaining transfers to complete =====
726715 # Synchronize before checking completion
727- hang_detector . checkpoint ()
716+
728717 dist .barrier ()
729- hang_detector .checkpoint ()
730718
731719 if is_ctx and my_requests :
732720 transceiver .check_context_transfer_status (None )
733721 print (f"[Rank { rank } ] CTX: All transfers completed ({ mode_str } )" , flush = True )
734- hang_detector . checkpoint ()
722+
735723 elif not is_ctx and my_requests :
736724 transceiver .check_gen_transfer_status (None )
737725 print (f"[Rank { rank } ] GEN: All transfers completed ({ mode_str } )" , flush = True )
738- hang_detector .checkpoint ()
739726
740727 if is_gen_first :
741728 # verify the aux data is unpacked correctly
@@ -755,9 +742,8 @@ def gather_and_verify_request(
755742 )
756743
757744 # Synchronize before verification
758- hang_detector . checkpoint ()
745+
759746 dist .barrier ()
760- hang_detector .checkpoint ()
761747
762748 # ===== Phase 4: Batch verify all requests =====
763749 # All ranks must participate in gather (collective op), so iterate all_requests.
@@ -814,9 +800,7 @@ def gather_and_verify_request(
814800 dist .broadcast (pass_tensor , src = 0 )
815801 assert pass_tensor .item () == 1 , "Some requests failed verification!"
816802
817- hang_detector .checkpoint ()
818803 dist .barrier ()
819- hang_detector .checkpoint ()
820804
821805 # ===== Phase 5: Cleanup requests =====
822806 # All ranks added all requests, so all need to remove them
@@ -827,8 +811,6 @@ def gather_and_verify_request(
827811 if rank == 0 :
828812 print (f"[Rank { rank } ] Cleanup completed ({ mode_str } )" )
829813
830- hang_detector .stop ()
831-
832814 # Cleanup
833815 dist .destroy_process_group ()
834816
@@ -909,13 +891,8 @@ def _wait_ctx_request_ready(transceiver, my_requests):
909891 return all_ready
910892
911893
912- def _run_gen_first1_transfer (rank , is_ctx , transceiver , my_requests , checkpoint_fn = None ):
894+ def _run_gen_first1_transfer (rank , is_ctx , transceiver , my_requests ):
913895 """Generation-first transfer: ctx prepares first, then gen receives and ctx sends."""
914-
915- def _checkpoint ():
916- if checkpoint_fn :
917- checkpoint_fn ()
918-
919896 # Step 1: Context side calls prepare_context_requests, no kvcache request is sent, thus no request
920897 # can reach CONTEXT_INIT state.
921898 if is_ctx :
@@ -928,9 +905,7 @@ def _checkpoint():
928905 assert req .state == LlmRequestState .DISAGG_CONTEXT_WAIT_SCHEDULER
929906 print (f"[Rank { rank } ] CTX: All requests are waiting for being scheduled" , flush = True )
930907
931- _checkpoint ()
932908 dist .barrier ()
933- _checkpoint ()
934909
935910 # Step 2: Generation side submits receive requests
936911 if not is_ctx :
@@ -944,15 +919,12 @@ def _checkpoint():
944919 f"[Rank { rank } ] GEN: Submitted { len (my_requests )} gen-first receive requests" ,
945920 flush = True ,
946921 )
947- _checkpoint ()
948922 dist .barrier ()
949- _checkpoint ()
950923
951924 if is_ctx :
952925 # Poll until all requests reach CONTEXT_INIT (peer info arrived)
953926 transceiver .prepare_context_requests (ctx_my_requests )
954927 _wait_ctx_request_ready (transceiver , ctx_my_requests )
955- _checkpoint ()
956928
957929 for req_idx , request in my_requests :
958930 print (
@@ -963,13 +935,8 @@ def _checkpoint():
963935 print (f"[Rank { rank } ] CTX: Submitted { len (my_requests )} send requests" , flush = True )
964936
965937
966- def _run_gen_first2_transfer (rank , is_ctx , transceiver , my_requests , checkpoint_fn = None ):
938+ def _run_gen_first2_transfer (rank , is_ctx , transceiver , my_requests ):
967939 """Generation-first transfer: gen receives first, then ctx prepares and sends."""
968-
969- def _checkpoint ():
970- if checkpoint_fn :
971- checkpoint_fn ()
972-
973940 # Step 1: Generation side submits receive requests, now context side doesn't know the requests
974941 # but gets kvcache requests first
975942 if not is_ctx :
@@ -983,9 +950,7 @@ def _checkpoint():
983950 f"[Rank { rank } ] GEN: Submitted { len (my_requests )} gen-first receive requests" ,
984951 flush = True ,
985952 )
986- _checkpoint ()
987953 dist .barrier ()
988- _checkpoint ()
989954 time .sleep (3 ) # wait for the receive requests to be submitted
990955 # Step 2: Context side calls prepare_context_requests, now context side knows the requests
991956 # all requests can reach CONTEXT_INIT state directly.
@@ -995,11 +960,8 @@ def _checkpoint():
995960 transceiver .prepare_context_requests (ctx_my_requests )
996961 print (f"[Rank { rank } ] CTX: Called prepare_context_requests" , flush = True )
997962 _wait_ctx_request_ready (transceiver , ctx_my_requests )
998- _checkpoint ()
999963
1000- _checkpoint ()
1001964 dist .barrier ()
1002- _checkpoint ()
1003965 # Step 3: Context side sends the data
1004966 if is_ctx :
1005967 for req_idx , request in my_requests :
0 commit comments