Commit 4a88713

emlinmeta-codesync[bot] authored and committed
Change from first element to a random element for cache-missing items (#4955)
Summary:
Pull Request resolved: #4955

X-link: https://github.com/facebookresearch/FBGEMM/pull/1974

In the inference ZCH backend, we cannot use the initializer to produce randomized init values for cache-missing items, because the initializer does not work in parallel read/write mode. The current behavior is to always take the first item in the hash map, which gives little randomization. This diff adds randomization for cache-missing IDs, and also adds a log of the cache-miss load in every batch.

Update:
- Changed back to using the hash map size, instead of the allocated block count, as the randomization base.
- Check whether the block is actually used. The reason is that a block can be allocated but never used; if we picked such a block, it would return all zeros.

Verified the new change in VG, with no performance regression.

Reviewed By: EddyLXJ, steven1327

Differential Revision: D83612329

fbshipit-source-id: 23e7f0d1e249c9a3117800c6f992104900fca748
1 parent 33e61aa commit 4a88713
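
For context, the sampling strategy described in the summary can be sketched in isolation: draw a random starting cursor bounded by the hash-map size, probe a bounded number of consecutive blocks for one that is actually in use, and fall back to the first map element if every probe misses. This is a minimal illustrative sketch only: ToyPool and pick_random_used_block are hypothetical stand-ins for FBGEMM's FixedBlockPool and the logic in dram_kv_embedding_cache, and std::mt19937 stands in for folly::Random.

    #include <cstddef>
    #include <random>
    #include <vector>

    // Hypothetical stand-in for the shard's block pool: slots can be
    // allocated but never used, in which case get_block() returns nullptr.
    struct ToyPool {
      std::vector<float*> slots;
      float* get_block(std::size_t cursor) const {
        return slots.empty() ? nullptr : slots[cursor % slots.size()];
      }
    };

    // Pick a random starting cursor bounded by the hash-map size (assumes
    // map_size > 0; the caller only reaches here when the map is non-empty),
    // probe up to 16 consecutive slots for a block that is actually in use,
    // and fall back to a known-good pointer (e.g. the first map element).
    float* pick_random_used_block(const ToyPool& pool, std::size_t map_size,
                                  float* fallback) {
      static thread_local std::mt19937 gen{std::random_device{}()};
      std::uniform_int_distribution<std::size_t> dist(0, map_size - 1);
      const std::size_t random_start = dist(gen);
      for (int attempts = 0; attempts < 16; ++attempts) {
        if (float* block = pool.get_block(random_start + attempts)) {
          return block;  // used block found: initialize the miss from it
        }
      }
      return fallback;  // every probe hit an allocated-but-unused slot
    }

Probing consecutive cursors keeps the cost bounded (at most 16 pool reads), while the random start spreads the sampled rows across the shard.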

File tree

2 files changed: +98 -4 lines changed

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h

Lines changed: 34 additions & 4 deletions
@@ -22,6 +22,7 @@
 #include <thrift/lib/cpp2/protocol/CompactProtocol.h>
 #include <thrift/lib/cpp2/protocol/Serializer.h>
 #include <torch/script.h>
+#include <random>
 #include "common/time/Time.h"

 #include "../ssd_split_embeddings_cache/initializer.h"
@@ -419,9 +420,36 @@ class DramKVInferenceEmbedding {
                   before_read_lock_ts;

           if (!wlmap->empty()) {
-            row_storage_data_ptr =
-                FixedBlockPool::data_ptr<weight_type>(
-                    wlmap->begin()->second);
+            // Simple block-based randomization using get_block with
+            // cursor
+            auto* pool = kv_store_.pool_by(shard_id);
+
+            // Random starting cursor based on map size for good
+            // entropy
+            size_t random_start =
+                folly::Random::rand32(wlmap->size());
+
+            // Try to find a used block starting from random
+            // position
+            weight_type* block = nullptr;
+            for (int attempts = 0; attempts < 16; ++attempts) {
+              block = pool->template get_block<weight_type>(
+                  random_start + attempts);
+              if (block != nullptr) {
+                // Block is used (not null)
+                row_storage_data_ptr =
+                    FixedBlockPool::data_ptr<weight_type>(block);
+                break;
+              }
+            }
+
+            // Fallback: if no used block found, use first element
+            // from map
+            if (block == nullptr) {
+              row_storage_data_ptr =
+                  FixedBlockPool::data_ptr<weight_type>(
+                      wlmap->begin()->second);
+            }
           } else {
             const auto& init_storage =
                 initializers_[shard_id]->row_storage_;
@@ -526,7 +554,9 @@ class DramKVInferenceEmbedding {
               read_lookup_cache_total_duration / num_shards_;
           read_acquire_lock_avg_duration_ +=
               read_acquire_lock_total_duration / num_shards_;
-          read_missing_load_avg_ += read_missing_load / num_shards_;
+          LOG_EVERY_MS(INFO, 5000)
+              << "get_kv_db_async total read_missing_load per batch: "
+              << read_missing_load;
           return std::vector<folly::Unit>(results.size());
        });
  };
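
A note on the logging change above: LOG_EVERY_MS(INFO, 5000) emits this message at most once every five seconds, so the per-batch cache-miss load shows up in the logs without flooding them; this replaces accumulating the value into read_missing_load_avg_.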

fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py

Lines changed: 64 additions & 0 deletions
@@ -249,3 +249,67 @@ def reader_thread() -> None:  # pyre-ignore
         self.assertTrue(equal_one_of(embs[5, :4], possible_embs))
         reader_thread.join()
         self.assertFalse(reader_failed_event.is_set())
+
+    def test_randomized_cache_miss_initialization(self) -> None:
+        """Test that cache misses use randomized data from existing blocks."""
+        num_shards = 8
+        uniform_init_lower: float = -0.01
+        uniform_init_upper: float = 0.01
+
+        # Create DRAM KV inference cache
+        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
+            num_shards, uniform_init_lower, uniform_init_upper
+        )
+        kv_embedding_cache.init(
+            [(32, 4, SparseType.FP16.as_int())],
+            32,
+            4,
+            torch.tensor([0, 100], dtype=torch.int64),
+        )
+
+        # Setup: Populate the cache with many initial values for better
+        # randomization diversity. Use 400 setup items so each of the
+        # 8 shards gets ~50 entries.
+        setup_indices = torch.arange(0, 400, dtype=torch.int64)  # 400 setup items
+        setup_weights = torch.randint(
+            1, 255, (400, 32), dtype=torch.uint8
+        )  # Non-zero values to ensure a randomization source
+        print(f"setup_weights: {setup_weights}")
+
+        # Populate cache
+        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
+
+        # Execute: Request cache misses - these should get randomized
+        # initialization. Use indices outside the range [0, 399] to
+        # ensure they are actual cache misses.
+        miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64)
+
+        # Get the cache miss results multiple times to check for randomization
+        results = []
+        for _ in range(5):
+            current_output = kv_embedding_cache.get_embeddings(miss_indices)
+            results.append(current_output.clone())
+
+        # Assert: Verify that randomization occurs. The results should not
+        # all be identical if randomization is working.
+        all_identical = True
+        for i in range(1, len(results)):
+            if not torch.equal(
+                results[0][:, :4], results[i][:, :4]
+            ):  # Only check first 4 columns (actual data)
+                all_identical = False
+                break
+
+        # Since we're using randomization, results should be different.
+        # Note: There's a small chance they could be identical by random
+        # chance, but with 5 trials of 5 vectors of 4 bytes, this is
+        # extremely unlikely.
+        self.assertFalse(
+            all_identical,
+            "Randomized cache miss initialization should produce different results",
+        )
+
+        # All results should be non-zero (since we populated the cache with
+        # non-zero random values)
+        for result in results:
+            # Check that at least some values are non-zero (indicating data
+            # came from existing blocks)
+            self.assertTrue(
+                torch.any(result[:, :4] != 0),
+                "Cache miss results should contain non-zero values when cache has data",
+            )
