-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathevictionPolicy.h
More file actions
116 lines (95 loc) · 5 KB
/
evictionPolicy.h
File metadata and controls
116 lines (95 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include <chrono>
#include <functional>
#include <optional>
#include <set>
#include <tuple>
#include <vector>
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
namespace tensorrt_llm::batch_manager::eviction_policy
{
/// @brief Abstract interface for policies that track free KV-cache blocks and decide which free block is handed out
/// next.
/// @details The cache manager claims blocks from the policy when it allocates or reuses them, and releases them back
/// to the policy afterwards. Free-block accounting is kept per cache level.
class BaseEvictionPolicy
{
public:
    virtual ~BaseEvictionPolicy() = default;

    // TODO(TRTLLM-1564): Don't use a separate `initialize` function. Ensure eviction policies can't be in-between a
    // state of construction and initialization.
    /// @brief Set up the policy's internal state.
    /// @param allBlocksById All blocks known to the cache manager (presumably indexed by block id, as the name
    /// suggests — TODO confirm against the caller).
    /// @param sizes Per-cache-level sizes (presumably the number of blocks at each cache level — TODO confirm).
    /// @param secondaryOffloadMinPriority Optional retention-priority threshold governing offload of blocks to
    /// secondary memory.
    virtual void initialize(std::vector<BlockPtr>& allBlocksById, std::vector<SizeType32> sizes,
        std::optional<executor::RetentionPriority> secondaryOffloadMinPriority)
        = 0;

    /// @brief Get a free block from the specified cache level.
    /// @param cacheLevel The cache level to take the block from.
    /// @returns The pointer to the free block, along with whether it can be offloaded.
    virtual std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) = 0;

    /// @brief Release a block, returning it to the policy's free pool.
    virtual void releaseBlock(BlockPtr block) = 0;

    /// @brief Release a block, returning it to the policy's free pool.
    /// @param toFront If true, prioritize the block for eviction.
    virtual void releaseBlock(BlockPtr block, bool toFront) = 0;

    /// @brief Get the number of free blocks at the given cache level.
    /// @param cacheLevel The cache level to query.
    virtual SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) = 0;

    /// @brief Claim a free block. Called when the cache manager allocates or reuses a new block.
    virtual void claimBlock(BlockPtr block) = 0;

    /// @brief Claim a free block, optionally assigning it a retention priority.
    /// @param priority Optional retention priority to assign to the block.
    /// @param durationMs Optional duration for which that priority applies. Presumably, once it elapses, the block
    /// falls back to the default priority (see LRUEvictionPolicy::refresh) — TODO confirm.
    virtual void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
        std::optional<std::chrono::milliseconds> durationMs)
        = 0;

    /// @brief Perform any per-iteration bookkeeping.
    virtual void refresh() = 0;

    /// @brief Consistency check of the policy's internal free-block bookkeeping.
    /// @returns Presumably true when the internal queues are consistent; intended for debugging/tests — TODO confirm.
    virtual bool verifyQueueIntegrity() = 0;
};
/// @brief Strict-weak-ordering comparator for BlockPtr, ordering blocks by ascending expiration time.
/// @details If two blocks expire in the same millisecond, their expiration times compare equal. As a fallback, ties
/// are broken by the address of the managed block, so distinct blocks never compare equivalent. This lets them coexist
/// in an ordered std::set.
/// @note The tie-break uses std::less rather than the built-in `<`. Comparing pointers to unrelated objects with `<`
/// gives an unspecified result, whereas std::less is guaranteed to impose a strict total order over pointers.
/// @note As with any ordered-container key, a block's expiration time must not change while it is stored in a
/// container ordered by this comparator.
struct ExpiringBlockComparator
{
    bool operator()(BlockPtr const& a, BlockPtr const& b) const
    {
        // Read each expiration time once instead of calling the getter up to four times.
        auto const aExpiration = a->getExpirationTime();
        auto const bExpiration = b->getExpirationTime();
        if (aExpiration != bExpiration)
        {
            return aExpiration < bExpiration;
        }
        return std::less<>{}(a.get(), b.get());
    }
};
class LRUEvictionPolicy : public BaseEvictionPolicy
{
public:
void initialize(std::vector<BlockPtr>& mAllBlocksById, std::vector<SizeType32> sizes,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority) override;
std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) override;
void releaseBlock(BlockPtr block) override;
void releaseBlock(BlockPtr block, bool toFront) override;
SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) override;
void claimBlock(BlockPtr block) override;
void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
std::optional<std::chrono::milliseconds> durationMs) override;
// Check the expiring blocks heap, and move expired blocks back to the default queue.
void refresh() override;
// Making this public and virtual makes it possible to test.
[[nodiscard]] virtual std::chrono::steady_clock::time_point::duration getTime() const;
bool verifyQueueIntegrity() override;
private:
//! \brief Add block to free block queue. Records all info needed to remove block from queue
void addToFreeBlockQueue(BlockPtr block, bool toFront);
//! \brief Remove block from free block queue, using info stored when block was added. It is always safe to call
//! this method \param block The block to be removed from free blocks queue. NOOP if block is not currently in queue
//! \return True if block was removed from free queue.
[[nodiscard]] bool removeFromFreeBlockQueue(BlockPtr block);
private:
// Queues of available leaf blocks, split by cache level and priority level
std::vector<std::vector<FreeBlocksQueue>> mFreeQueues;
// Iterators to block entries in mFreeQueues. Holds ALL arguments needed to remove block from free queue
std::vector<std::optional<std::tuple<SizeType32, SizeType32, FreeBlocksQueue::iterator>>> mFreeBlockIterators;
// Amount of free blocks at each cache level
std::vector<SizeType32> mNumFreeBlocksPerLevel;
// Secondary offload threshold. Blocks below this priority won't be evicted.
executor::RetentionPriority mSecondaryOffloadMinPriority;
// Heap of block times
std::set<BlockPtr, ExpiringBlockComparator> mExpiringBlockHeap;
};
} // namespace tensorrt_llm::batch_manager::eviction_policy