19 changes: 19 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -735,6 +735,12 @@ class WindowBlockManager
return 0;
}

//! \brief Reset block-reuse state: clear the context-block hash map and recreate the cached-blocks root, so previously cached blocks can no longer be matched.
void resetReuseState()
{
mContextBlocksByHash.clear();
mCachedBlocksRoot = std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
}

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -1120,6 +1126,13 @@ class BlockManager
return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex);
}

//! \brief Reset block-reuse state in every per-window block manager.
void resetReuseState()
{
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
manager.resetReuseState();
}
}

private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
{
@@ -1290,6 +1303,7 @@ class BaseKVCacheManager

virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0;

[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

@@ -1633,6 +1647,11 @@ class KVCacheManager : public BaseKVCacheManager
mBlockManager.flushIterationEvents();
}

void resetReuseState() override
{
mBlockManager.resetReuseState();
}

/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
///
/// @param inputLength The number of input tokens in the sequence.
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -420,7 +420,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds)
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds)
.def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents)
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState);

py::enum_<tbk::CacheType>(m, "CacheType")
.value("SELF", tbk::CacheType::kSELF)
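With the binding in place, the reset can be driven from Python. A minimal sketch, assuming a bound BaseKVCacheManager instance is already in hand (how it is obtained is not shown in this PR and is hypothetical); the call drops all reuse state so later requests cannot match KV blocks cached before, for example, a weight update:

    # Sketch only: `kv_cache_manager` stands in for a bound BaseKVCacheManager;
    # obtaining it from the executor internals is assumed, not part of this PR.
    def refresh_after_weight_update(model, weights, kv_cache_manager):
        model.load_weights(weights)           # partial update (see modeling_utils.py below)
        kv_cache_manager.reset_reuse_state()  # stale cached blocks can no longer be reused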
20 changes: 13 additions & 7 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -724,6 +724,9 @@ def load_single_module(name, module):
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
# Temporary fix: skip fused sub-modules whose weights are absent from this
# update, so partial weight updates work in the old path.
if not fw:
continue
if new_name in ['k_proj', 'v_proj']:
num_kv_heads_list = [num_kv_heads
] * len(fw) if isinstance(
@@ -740,15 +743,18 @@
}

module_weights.append(fw)
module.load_weights(weights=module_weights)
if module_weights:
module.load_weights(weights=module_weights)

else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
if module_weights:
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])

if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
False) in ["True", "true", "1", "yes", "y"]:
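The early `continue` and the `if fw:` / `if module_weights:` guards matter because a partial update carries only a subset of tensors. A minimal sketch of the assumed `filter_weights` behavior (its exact implementation is not shown in this diff) and the guard:

    # Assumed behavior: filter_weights keeps entries whose names start with `prefix.`
    def filter_weights(prefix, weights):
        return {name[len(prefix) + 1:]: tensor
                for name, tensor in weights.items()
                if name.startswith(prefix + '.')}

    partial = {'layers.0.mlp.weight': 'tensor-0'}  # update touches one module only
    fw = filter_weights('layers.1.mlp', partial)
    if not fw:
        pass  # untouched module: skip load_weights instead of failing on empty input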
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -163,14 +163,16 @@ def __init__(self,
return_log_probs: bool = False,
return_context_logits: bool = False,
return_generation_logits: bool = False,
exclude_last_generation_logits: bool = False):
exclude_last_generation_logits: bool = False,
success: bool = False):
self._streaming = streaming
self._context_logits = LogitsStorage(
prompt_len, use_device_memory) if return_context_logits else None
self._generation_logits = LogitsStorage(
max_new_tokens, use_device_memory, exclude_last_generation_logits
) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
self._success = success

def append_context_logits(self, context_logits: torch.Tensor):
if self._context_logits:
@@ -246,8 +248,9 @@ def __getattr__(self, item):
return getattr(result, item)

def deserialize(self):
self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result)
if self._result is not None:
self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result)


@dataclass
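Guarding `deserialize` makes it a safe no-op when no serialized payload was ever attached. A minimal illustration of the None-guard pattern with hypothetical names (the real decode goes through `tensorrt_llm.bindings.executor.deserialize_result`):

    # Hypothetical stand-in showing the None-guard; not the real LlmResult class.
    class ResultHolder:
        def __init__(self, serialized=None):
            self._result = serialized

        def deserialize(self, decode=lambda b: b.decode()):
            if self._result is not None:  # nothing attached -> safe no-op
                self._result = decode(self._result)

    ResultHolder().deserialize()       # no payload: returns without error
    ResultHolder(b'ok').deserialize()  # payload: decoded in place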
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1087,7 +1087,6 @@ def init_meta_tensor(t: torch.Tensor):
weights = load_weights(model.llm_checkpoint_dir)
else:
weights = load_weights(checkpoint_dir)

model.load_weights(weights)

if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(