
Commit 53648bb

shuyixiongdavidmlwjoyang-nv authored and committed
[TRTLLM-8511][feat] Add update_weights and sleep_wakeup support for rl integration (NVIDIA#8302)
Signed-off-by: shuyix <[email protected]>
Co-authored-by: Liwei Ma <[email protected]>
Co-authored-by: Jonas Yang CN <[email protected]>
Signed-off-by: FredricZ-2007 <[email protected]>
1 parent dea246e commit 53648bb

File tree: 23 files changed, +852 -185 lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 21 additions & 0 deletions

@@ -871,6 +871,13 @@ class WindowBlockManager
         return mIsValidStoreForReuseSequence.at(requestId);
     }
 
+    void resetReuseState()
+    {
+        std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
+        mCachedBlocksRoot
+            = std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
+    }
+
 private:
     //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
     void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);

@@ -1347,6 +1354,14 @@ class BlockManager
         return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId);
     }
 
+    void resetReuseState()
+    {
+        for (auto& [windowSize, manager] : mWindowBlockManagers)
+        {
+            manager.resetReuseState();
+        }
+    }
+
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
     {

@@ -1533,6 +1548,7 @@ class BaseKVCacheManager
 
     virtual void refreshBlocks() = 0;
     virtual void flushIterationEvents() = 0;
+    virtual void resetReuseState() = 0;
 
     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
 

@@ -1913,6 +1929,11 @@ class KVCacheManager : public BaseKVCacheManager
         return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
     }
 
+    void resetReuseState() override
+    {
+        mBlockManager.resetReuseState();
+    }
+
     /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
     ///
     /// @param inputLength The number of input tokens in the sequence.

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion

@@ -482,7 +482,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
-        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>());
+        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
+        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
 
     nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
         .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion

@@ -486,7 +486,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             py::call_guard<py::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
-        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>());
+        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
+        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
 
     py::enum_<tbk::CacheType>(m, "CacheType")
         .value("SELF", tbk::CacheType::kSELF)
New file — Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
import torch

from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger


class WorkerExtension:
    """Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.

    This class can be injected into tensorrt_llm.LLM() by specifying it via the
    ray_worker_extension_cls parameter in LLMArgs when using orchestrator_type='ray'.
    The extension methods will be available on each Ray worker and can be called via
    the LLM's collective RPC mechanism.

    Examples:
        Creating an LLM with worker extension:

        >>> llm = LLM(
        ...     model=model_dir,
        ...     orchestrator_type="ray",
        ...     ray_worker_extension_cls="rlhf_utils.WorkerExtension",
        ... )

        Calling extension methods via collective RPC:

        >>> llm._collective_rpc("update_weights", args=(ipc_handles,))
    """

    @control_action_decorator
    def update_weights(self, ipc_handles: dict):
        """Update model weights from IPC (Inter-Process Communication) handles.

        This method receives shared memory handles from another process (typically FSDP training),
        reconstructs tensors from these handles, and loads them into the TensorRT-LLM model.
        Uses the control_action_decorator to ensure all active requests are finished before
        updating weights.

        Args:
            ipc_handles: Dictionary mapping device UUIDs to lists of (param_name, tensor_handle) tuples.
                Each tensor_handle is a tuple of (func, args) for reconstructing the tensor.

        Raises:
            ValueError: If the current device's UUID is not found in ipc_handles.
            Exception: Re-raises any exception encountered during weight update.
        """
        try:
            logger.info("Update weights from IPC handles")
            device_uuid = get_device_uuid(self.device_id)

            if device_uuid not in ipc_handles:
                raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

            weights = {}
            all_handles = ipc_handles[device_uuid]

            for param_name, tensor_handle in all_handles:
                func, args = tensor_handle
                list_args = list(args)
                list_args[6] = self.device_id  # Set target device
                tensor = func(*list_args)
                weights[param_name] = tensor

            self.engine.model_engine.model.load_weights(weights)
            torch.cuda.synchronize()
            self.engine.reset_prefix_cache()

        except Exception as e:
            logger.error("Encountered an error in update_weights")
            raise e

    def check_weights_updated(self):
        """Check if the weights are updated to 0."""
        weights_updated = True
        for name, p in self.engine.model_engine.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
        return weights_updated
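
To make the RPC flow concrete, here is a hedged sketch of the producer side, i.e. how a training process might package its parameters into the ipc_handles dictionary that update_weights expects. Only the dictionary layout (device UUID mapped to a list of (name, (func, args)) tuples) comes from the docstring above; the use of torch.multiprocessing.reductions.reduce_tensor and reusing get_device_uuid on the training side are assumptions for illustration.

    # Hypothetical producer-side sketch (training process). Assumes reduce_tensor is a
    # suitable way to obtain (rebuild_func, args) CUDA IPC handles; not part of this diff.
    import torch
    from torch.multiprocessing.reductions import reduce_tensor

    from tensorrt_llm._torch.utils import get_device_uuid  # same helper used by update_weights


    def build_ipc_handles(model: torch.nn.Module, device_id: int) -> dict:
        """Package model parameters as IPC handles keyed by the GPU's UUID."""
        device_uuid = get_device_uuid(device_id)
        handles = []
        for name, param in model.named_parameters():
            func, args = reduce_tensor(param.detach())  # (rebuild_func, args) tuple
            handles.append((name, (func, args)))
        return {device_uuid: handles}


    # Consumer side: broadcast the handles to every Ray worker via collective RPC,
    # where `llm` is the tensorrt_llm.LLM instance configured with ray_worker_extension_cls.
    # ipc_handles = build_ipc_handles(training_model, device_id=0)
    # llm._collective_rpc("update_weights", args=(ipc_handles,))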

tensorrt_llm/_ray_utils.py

Lines changed: 15 additions & 0 deletions

@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 from contextlib import contextmanager
+from typing import Callable
 
 try:
     import ray

@@ -26,3 +28,16 @@ def unwrap_ray_errors():
         yield
     except ray.exceptions.RayTaskError as e:
         raise e.as_instanceof_cause() from e
+
+
+def control_action_decorator(func: Callable) -> Callable:
+    """
+    Decorator that wraps a method to use control_action context manager.
+    """
+
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.engine.control_action():
+            return func(self, *args, **kwargs)
+
+    return wrapper
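
The decorator is intended for methods of a Ray worker extension that must not run concurrently with in-flight requests, as update_weights does above. A minimal hypothetical sketch of another extension method using it; the class, method name and body are illustrative, not part of this commit.

    # Hypothetical sketch: any worker-extension method can reuse the same guard so it
    # only runs inside self.engine.control_action(), i.e. after active requests drain.
    from tensorrt_llm._ray_utils import control_action_decorator


    class MyExtension:

        @control_action_decorator
        def clear_kv_cache(self):
            # Mirrors how update_weights is guarded; reset_prefix_cache is the same
            # engine call used by WorkerExtension above.
            self.engine.reset_prefix_cache()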

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 48 additions & 29 deletions

@@ -871,6 +871,8 @@ def load_single_module(name, module):
                 for new_name in params_map[names[-1]]:
                     fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                         weights)
+                    if not fw:
+                        continue
                     if new_name in ['k_proj', 'v_proj']:
                         num_kv_heads_list = [num_kv_heads
                                              ] * len(fw) if isinstance(

@@ -887,23 +889,29 @@ def load_single_module(name, module):
                         }
 
                     module_weights.append(fw)
-                module.load_weights(weights=module_weights)
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    module.load_weights(weights=module_weights)
+
             else:
                 module_weights = filter_weights(name, weights)
-                if hasattr(module, 'load_weights'):
-                    module.load_weights(weights=[module_weights])
-                else:
-                    for n, p in module._parameters.items():
-                        if p is not None:
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    if hasattr(module, 'load_weights'):
+                        module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module.named_parameters(recurse=False):
                             p.data.copy_(module_weights[n][:])
 
     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                       "True") in ["True", "true", "1", "yes", "y"]:
-        for name, module in tqdm(list(model.named_modules()),
+        for name, module in tqdm(list(
+                model.named_modules(remove_duplicate=False)),
                                  desc="Loading weights"):
            load_single_module(name, module)
     else:
-        all_modules = dict(model.named_modules())
+        # remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
+        all_modules = dict(model.named_modules(remove_duplicate=False))
         serial_load_modules = []
         if preload_weight_modules is not None:
             for module in preload_weight_modules:

@@ -919,10 +927,13 @@ def load_single_module(name, module):
                     del all_modules[module]
             pbar.close()
 
-        pbar = tqdm(list(model.named_modules()),
+        pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
                     desc="Loading weights concurrently")
-        args_list = [(name, module) for name, module in model.named_modules()
-                     if name not in serial_load_modules]
+        args_list = [
+            (name, module)
+            for name, module in model.named_modules(remove_duplicate=False)
+            if name not in serial_load_modules
+        ]
         run_concurrently(load_single_module, args_list, pbar=pbar)

@@ -950,31 +961,36 @@ def load_single_module(name, module):
            if weight_mapper.does_require_special_handling(module_name):
                module_weights = weight_mapper.apply_callbacks(
                    module, module_name, module_names_breakdown, weights)
-                module.load_weights(weights=module_weights)
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    module.load_weights(weights=module_weights)
            else:
                module_weights = weight_mapper.filter_weights(name, weights)
-                if weight_mapper.is_special_instance_module(module):
-                    weight_mapper.handle_special_instance_module(
-                        module, module_name, module_weights)
-
-                elif hasattr(module, 'load_weights'):
-                    if "linear_attn.conv1d" in name:
-                        module_weights['weight'] = module_weights[
-                            'weight'].squeeze(dim=1)
-                    module.load_weights(weights=[module_weights])
-                else:
-                    for n, p in module._parameters.items():
-                        if p is not None:
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    if weight_mapper.is_special_instance_module(module):
+                        weight_mapper.handle_special_instance_module(
+                            module, module_name, module_weights)
+                    elif hasattr(module, 'load_weights'):
+                        if module_weights:
+                            if "linear_attn.conv1d" in name:
+                                module_weights['weight'] = module_weights[
+                                    'weight'].squeeze(dim=1)
+                            module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module.named_parameters(recurse=False):
                             weight_mapper.handle_manual_copy(
                                 module_name, module_weights, n, p)
 
    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                      "True") in ["True", "true", "1", "yes", "y"]:
-        for name, module in tqdm(list(model.named_modules()),
+        for name, module in tqdm(list(
+                model.named_modules(remove_duplicate=False)),
                                 desc="Loading weights"):
            load_single_module(name, module)
    else:
-        all_modules = dict(model.named_modules())
+        # remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
+        all_modules = dict(model.named_modules(remove_duplicate=False))
        serial_load_modules = []
        if preload_weight_modules is not None:
            for module in preload_weight_modules:

@@ -990,8 +1006,11 @@ def load_single_module(name, module):
                    del all_modules[module]
            pbar.close()
 
-        pbar = tqdm(list(model.named_modules()),
+        pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
                    desc="Loading weights concurrently")
-        args_list = [(name, module) for name, module in model.named_modules()
-                     if name not in serial_load_modules]
+        args_list = [
+            (name, module)
+            for name, module in model.named_modules(remove_duplicate=False)
+            if name not in serial_load_modules
+        ]
        run_concurrently(load_single_module, args_list, pbar=pbar)
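
The repeated switch from model.named_modules() to model.named_modules(remove_duplicate=False) is the load-bearing change here: PyTorch's default traversal yields each module object only once, so when one module is reachable under two names (the next_layer_layernorm aliasing mentioned in the added comments), the name visited second is silently skipped and its weights are never reloaded. A small self-contained sketch of the difference, using an invented toy model:

    import torch.nn as nn

    # Toy model where the same LayerNorm object is registered under two names,
    # mimicking the next_layer_layernorm aliasing referenced in the comments above.
    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.ln = nn.LayerNorm(8)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.block0 = Block()
            self.block1 = Block()
            # Alias: block0 also references block1's layernorm.
            self.block0.next_layer_layernorm = self.block1.ln

    m = Toy()
    default_names = {name for name, _ in m.named_modules()}
    all_names = {name for name, _ in m.named_modules(remove_duplicate=False)}
    # With the default traversal the original 'block1.ln' is skipped, because the same
    # module object was already visited via the alias 'block0.next_layer_layernorm'.
    print(all_names - default_names)  # {'block1.ln'}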

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 1 deletion

@@ -670,6 +670,7 @@ def create_py_executor_instance(
     peft_cache_config: Optional[PeftCacheConfig] = None,
     scheduler_config: Optional[SchedulerConfig] = None,
     cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
+    virtual_memory_pools: Optional[dict] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

@@ -818,7 +819,8 @@ def create_py_executor_instance(
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         kv_connector_manager=kv_connector_manager,
         max_seq_len=max_seq_len,
-        peft_cache_config=peft_cache_config)
+        peft_cache_config=peft_cache_config,
+        virtual_memory_pools=virtual_memory_pools)
 
 
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 3 additions & 0 deletions

@@ -108,6 +108,9 @@ class PyTorchConfig:
     # If true, ONLY the vision encoder part of the full model is loaded/executed.
     mm_encoder_only: bool = False
 
+    # Enable extra setup to support sleep feature.
+    enable_sleep: bool = False
+
     # If true, adjust PyTorch CUDA memory fraction to correspond to the
     # total GPU memory minus the statically allocated engine memory.
     # If false, set the PyTorch CUDA memory fraction to 1.0.
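
Since enable_sleep is an ordinary PyTorchConfig field with a default, turning it on is just a matter of setting it where the config is built. A minimal sketch, assuming direct construction of the dataclass; how the flag is plumbed through higher-level LLM arguments is not shown in this excerpt.

    # Hypothetical sketch: opting into the extra setup for the sleep/wakeup feature.
    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

    pytorch_config = PyTorchConfig(
        enable_sleep=True,  # new field added by this commit; defaults to False
    )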
