Skip to content

Commit 6d8e16a

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Add util func for kvzch eviction mask (pytorch#4610)
Summary: X-link: pytorch/torchrec#3246 X-link: facebookresearch/FBGEMM#1645 Adding a util func for kvzch to get a eviction mask using inference threshold. This is used in publish. Reviewed By: yixin94 Differential Revision: D79045178
1 parent fda4ccf commit 6d8e16a

File tree

4 files changed

+263
-0
lines changed

4 files changed

+263
-0
lines changed

fbgemm_gpu/fbgemm_gpu/kvzch_util.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import time
8+
9+
import torch
10+
from torchrec.modules.embedding_configs import (
11+
CountBasedEvictionPolicy,
12+
CountTimestampMixedEvictionPolicy,
13+
FeatureL2NormBasedEvictionPolicy,
14+
NoEvictionPolicy,
15+
TimestampBasedEvictionPolicy,
16+
VirtualTableEvictionPolicy,
17+
)
18+
19+
20+
def parse_metadata_tensor(metadata_tensor: torch.Tensor):
21+
"""
22+
Parses a kvzch metadata tensor where each element encodes three pieces of information
23+
packed into a single 64-bit integer.
24+
The 64-bit integer layout is as follows:
25+
- The lower 32 bits (bits 0-31) represent the 'timestamp', stored as a uint32.
26+
This timestamp is typically in seconds and can represent a range of over 120 years.
27+
- The upper 32 bits (bits 32-63) encode two fields packed together:
28+
* The lower 31 bits of this upper half (bits 32-62 overall) represent 'count',
29+
a 31-bit unsigned integer indicating a usage count or score.
30+
* The highest bit of this upper half (bit 63 overall) represents 'used',
31+
a boolean flag indicating whether the block is currently in use.
32+
This function extracts these three components from each 64-bit integer in the tensor:
33+
- 'timestamps' as a uint32 array
34+
- 'counts' as a uint32 array (31 bits used)
35+
- 'used' as a boolean array
36+
Args:
37+
metadata_tensor (torch.Tensor): A 1D tensor of dtype torch.int64, where each
38+
element encodes timestamp, count, and used flag.
39+
Returns:
40+
tuple: (timestamps, counts, used)
41+
- timestamps (tensor): int64 array of timestamps extracted from the tensor.
42+
- counts (tensor): int64 array of counts extracted from the tensor.
43+
- used (tensor): boolean array indicating usage flags extracted from the tensor.
44+
"""
45+
assert metadata_tensor.dtype == torch.int64
46+
timestamps = metadata_tensor & 0xFFFFFFFF # Extract lower 32 bits as timestamp
47+
count_used = (
48+
metadata_tensor >> 32
49+
) # Extract upper 32 bits containing count and used
50+
counts = count_used & 0x7FFFFFFF # Lower 31 bits of upper half as count
51+
used = ((count_used >> 31) & 1).to(
52+
torch.bool
53+
) # Highest bit of upper half as used flag
54+
return timestamps, counts, used
55+
56+
57+
def get_kv_zch_eviction_mask(
58+
metadata_tensor: torch.Tensor,
59+
eviction_policy: VirtualTableEvictionPolicy,
60+
):
61+
"""
62+
Returns a boolean mask indicating which blocks should be evicted from the KV cache.
63+
The eviction policy is determined by the 'eviction_policy' argument.
64+
Args:
65+
metadata_tensor (torch.Tensor): A 1D tensor of dtype torch.int64, where each
66+
element encodes timestamp, count, and used flag.
67+
eviction_policy (VirtualTableEvictionPolicy): The eviction policy to use.
68+
Returns:
69+
torch.Tensor: A 1D boolean tensor of the same size as 'metadata_tensor', where False indicates a block should be evicted.
70+
"""
71+
72+
eviction_mask = torch.ones_like(
73+
metadata_tensor, dtype=torch.bool
74+
) # Initialize mask to True (keep all blocks)
75+
if isinstance(eviction_policy, NoEvictionPolicy):
76+
return eviction_mask
77+
78+
# Parse the metadata tensor to extract timestamps, counts, and used flags
79+
timestamps, counts, _ = parse_metadata_tensor(metadata_tensor)
80+
81+
# Apply the eviction policy to determine which blocks should be evicted
82+
# Check which policy is being used
83+
if isinstance(eviction_policy, CountBasedEvictionPolicy):
84+
inference_eviction_threshold = eviction_policy.inference_eviction_threshold
85+
eviction_mask = counts >= inference_eviction_threshold
86+
87+
elif isinstance(eviction_policy, TimestampBasedEvictionPolicy):
88+
inference_eviction_ttl_mins = eviction_policy.inference_eviction_ttl_mins
89+
if inference_eviction_ttl_mins != 0: # eviction_ttl_mins == 0 means no eviction
90+
current_time = int(time.time())
91+
eviction_mask = (
92+
current_time - timestamps
93+
) <= inference_eviction_ttl_mins * 60
94+
95+
elif isinstance(eviction_policy, CountTimestampMixedEvictionPolicy):
96+
inference_eviction_threshold = eviction_policy.inference_eviction_threshold
97+
inference_eviction_ttl_mins = eviction_policy.inference_eviction_ttl_mins
98+
current_time = int(time.time())
99+
eviction_ttl_secs = inference_eviction_ttl_mins * 60
100+
if inference_eviction_threshold == 0:
101+
count_mask = torch.ones_like(counts, dtype=torch.bool)
102+
else:
103+
count_mask = counts >= inference_eviction_threshold
104+
105+
if inference_eviction_ttl_mins == 0:
106+
timestamp_mask = torch.ones_like(counts, dtype=torch.bool)
107+
else:
108+
timestamp_mask = (current_time - timestamps) <= eviction_ttl_secs
109+
eviction_mask = count_mask & timestamp_mask
110+
111+
elif isinstance(eviction_policy, FeatureL2NormBasedEvictionPolicy):
112+
# Feature L2 norm-based eviction logic
113+
# No op for now
114+
pass
115+
else:
116+
raise ValueError("Unsupported eviction policy")
117+
118+
return eviction_mask

fbgemm_gpu/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ setuptools_git_versioning
2828
tabulate
2929
patchelf
3030
fairscale
31+
torchrec

fbgemm_gpu/requirements_genai.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ setuptools_git_versioning
3030
tabulate
3131
patchelf
3232
fairscale
33+
torchrec
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import time
10+
import unittest
11+
12+
import numpy as np
13+
import torch
14+
15+
from fbgemm_gpu.kvzch_util import get_kv_zch_eviction_mask, parse_metadata_tensor
16+
from torchrec.modules.embedding_configs import (
17+
CountBasedEvictionPolicy,
18+
CountTimestampMixedEvictionPolicy,
19+
NoEvictionPolicy,
20+
TimestampBasedEvictionPolicy,
21+
VirtualTableEvictionPolicy,
22+
)
23+
24+
from ..common import gpu_unavailable
25+
26+
27+
@unittest.skipIf(*gpu_unavailable)
28+
class KvzchUtilsTest(unittest.TestCase):
29+
def test_basic_parsing(self) -> None:
30+
"""
31+
Test typical parsing including used=0 and used=1 cases.
32+
"""
33+
# Compose metadata values as 64-bit integers:
34+
# [timestamp=7, count=13, used=0]
35+
v1 = 7 | (13 << 32) # used=0 (highest bit not set)
36+
# [timestamp=42, count=99, used=1]
37+
# Used=1 is highest bit; encode as a negative int64 in Python to avoid overflow
38+
v2 = (42 | (99 << 32)) - (1 << 63) # set highest bit
39+
# [timestamp=0xABCDEF01, count=0x1ABCDE0, used=1]
40+
v3 = (0xABCDEF01 | (0x1ABCDE0 << 32)) - (1 << 63)
41+
vals = [v1, v2, v3]
42+
tensor = torch.tensor(vals, dtype=torch.int64)
43+
44+
timestamps, counts, used = parse_metadata_tensor(tensor)
45+
46+
np.testing.assert_array_equal(
47+
timestamps.numpy(), np.array([7, 42, 0xABCDEF01], dtype=np.uint32)
48+
)
49+
np.testing.assert_array_equal(
50+
counts.numpy(), np.array([13, 99, 0x1ABCDE0], dtype=np.uint32)
51+
)
52+
np.testing.assert_array_equal(
53+
used.numpy(), np.array([False, True, True], dtype=bool)
54+
)
55+
56+
def test_edge_cases(self) -> None:
57+
"""
58+
Test edge cases including all zeros, max values, min values, and different used flags.
59+
"""
60+
# All fields zero, used=0
61+
v1 = 0
62+
# Max timestamp, max count, used=0
63+
v2 = 0xFFFFFFFF | (0x7FFFFFFF << 32) # Used=0 (highest bit = 0)
64+
# Min timestamp, min count, used=1
65+
v3 = 0 - (1 << 63) # All fields 0, only highest bit set (used=1)
66+
67+
vals = [v1, v2, v3]
68+
tensor = torch.tensor(vals, dtype=torch.int64)
69+
70+
timestamps, counts, used = parse_metadata_tensor(tensor)
71+
72+
np.testing.assert_array_equal(
73+
timestamps.numpy(), np.array([0, 0xFFFFFFFF, 0], dtype=np.uint32)
74+
)
75+
np.testing.assert_array_equal(
76+
counts.numpy(), np.array([0, 0x7FFFFFFF, 0], dtype=np.uint32)
77+
)
78+
np.testing.assert_array_equal(
79+
used.numpy(), np.array([False, False, True], dtype=bool)
80+
)
81+
82+
def test_invalid_dtype(self) -> None:
83+
"""
84+
Test that an assertion is raised for wrong dtype.
85+
"""
86+
tensor = torch.tensor([1, 2, 3], dtype=torch.float32)
87+
with self.assertRaises(AssertionError):
88+
parse_metadata_tensor(tensor)
89+
90+
91+
class GetKvZchEvictionMaskTest(unittest.TestCase):
92+
def setUp(self) -> None:
93+
# Prepare some metadata values with timestamp, count, used
94+
# Use negative numbers to represent highest bit set (used=1)
95+
self.vals = [
96+
(100 | (5 << 32)), # used=0
97+
(int(time.time()) - 60 | (10 << 32))
98+
- (1 << 63), # used=1, timestamp 1 min ago
99+
(int(time.time()) - 3600 | (15 << 32))
100+
- (1 << 63), # used=1, timestamp 1 hour ago
101+
]
102+
self.metadata_tensor = torch.tensor(self.vals, dtype=torch.int64)
103+
104+
def test_count_based_eviction(self) -> None:
105+
policy = CountBasedEvictionPolicy(inference_eviction_threshold=10)
106+
mask = get_kv_zch_eviction_mask(self.metadata_tensor, policy)
107+
# counts are 5,10,15; threshold=10; keep counts >= 10
108+
expected = torch.tensor([False, True, True], dtype=torch.bool)
109+
self.assertTrue(torch.equal(mask, expected))
110+
111+
def test_timestamp_based_eviction(self) -> None:
112+
policy = TimestampBasedEvictionPolicy(inference_eviction_ttl_mins=30)
113+
mask = get_kv_zch_eviction_mask(self.metadata_tensor, policy)
114+
# timestamps: 100 (old), now-60s, now-3600s
115+
# TTL=30min=1800s, keep timestamps within 1800s
116+
expected = torch.tensor([False, True, False], dtype=torch.bool)
117+
self.assertTrue(torch.equal(mask, expected))
118+
119+
def test_count_timestamp_mixed_eviction(self) -> None:
120+
policy = CountTimestampMixedEvictionPolicy(
121+
inference_eviction_threshold=10, inference_eviction_ttl_mins=30
122+
)
123+
mask = get_kv_zch_eviction_mask(self.metadata_tensor, policy)
124+
# count mask: counts >= 10 -> [False, True, True]
125+
# timestamp mask: within 1800s -> [False, True, False]
126+
# combined mask = count_mask & timestamp_mask
127+
expected = torch.tensor([False, True, False], dtype=torch.bool)
128+
self.assertTrue(torch.equal(mask, expected))
129+
130+
def test_no_eviction_policy(self) -> None:
131+
policy = NoEvictionPolicy()
132+
mask = get_kv_zch_eviction_mask(self.metadata_tensor, policy)
133+
# No eviction, mask all True
134+
expected = torch.ones_like(self.metadata_tensor, dtype=torch.bool)
135+
self.assertTrue(torch.equal(mask, expected))
136+
137+
def test_unsupported_policy(self) -> None:
138+
class DummyPolicy(VirtualTableEvictionPolicy):
139+
pass
140+
141+
policy = DummyPolicy()
142+
with self.assertRaises(ValueError):
143+
get_kv_zch_eviction_mask(self.metadata_tensor, policy)

0 commit comments

Comments
 (0)