|
11 | 11 | from typing import Dict
|
12 | 12 |
|
13 | 13 | import torch
|
| 14 | +from torchrec.fb.modules.hash_mc_evictions import ( |
| 15 | + HashZchEvictionConfig, |
| 16 | + HashZchEvictionPolicyName, |
| 17 | +) |
| 18 | +from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule |
| 19 | +from torchrec.modules.embedding_configs import EmbeddingConfig |
14 | 20 | from torchrec.modules.mc_modules import (
|
15 | 21 | average_threshold_filter,
|
16 | 22 | DistanceLFU_EvictionPolicy,
|
17 | 23 | dynamic_threshold_filter,
|
18 | 24 | LFU_EvictionPolicy,
|
19 | 25 | LRU_EvictionPolicy,
|
| 26 | + ManagedCollisionCollection, |
20 | 27 | MCHManagedCollisionModule,
|
21 | 28 | probabilistic_threshold_filter,
|
22 | 29 | )
|
23 |
| -from torchrec.sparse.jagged_tensor import JaggedTensor |
| 30 | +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor |
24 | 31 |
|
25 | 32 |
|
26 | 33 | class TestEvictionPolicy(unittest.TestCase):
|
@@ -427,3 +434,75 @@ def test_fx_jit_script_not_training(self) -> None:
|
427 | 434 | model.train(False)
|
428 | 435 | gm = torch.fx.symbolic_trace(model)
|
429 | 436 | torch.jit.script(gm)
|
| 437 | + |
| 438 | + def test_mc_module_forward(self) -> None: |
| 439 | + embedding_configs = [ |
| 440 | + EmbeddingConfig( |
| 441 | + name="t1", |
| 442 | + num_embeddings=100, |
| 443 | + embedding_dim=8, |
| 444 | + feature_names=["f1", "f2"], |
| 445 | + ), |
| 446 | + EmbeddingConfig( |
| 447 | + name="t2", |
| 448 | + num_embeddings=100, |
| 449 | + embedding_dim=8, |
| 450 | + feature_names=["f3", "f4"], |
| 451 | + ), |
| 452 | + ] |
| 453 | + |
| 454 | + mc_modules = { |
| 455 | + "t1": HashZchManagedCollisionModule( |
| 456 | + zch_size=100, |
| 457 | + device=torch.device("cpu"), |
| 458 | + total_num_buckets=1, |
| 459 | + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, |
| 460 | + eviction_config=HashZchEvictionConfig( |
| 461 | + features=[], |
| 462 | + single_ttl=10, |
| 463 | + ), |
| 464 | + ), |
| 465 | + "t2": HashZchManagedCollisionModule( |
| 466 | + zch_size=100, |
| 467 | + device=torch.device("cpu"), |
| 468 | + total_num_buckets=1, |
| 469 | + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, |
| 470 | + eviction_config=HashZchEvictionConfig( |
| 471 | + features=[], |
| 472 | + single_ttl=10, |
| 473 | + ), |
| 474 | + ), |
| 475 | + } |
| 476 | + for mc_module in mc_modules.values(): |
| 477 | + mc_module.reset_inference_mode() |
| 478 | + mc_ebc = ManagedCollisionCollection( |
| 479 | + # Pyre-ignore [6]: In call `ManagedCollisionCollection.__init__`, for argument `managed_collision_modules`, expected `Dict[str, ManagedCollisionModule]` but got `Dict[str, HashZchManagedCollisionModule]` |
| 480 | + managed_collision_modules=mc_modules, |
| 481 | + embedding_configs=embedding_configs, |
| 482 | + ) |
| 483 | + kjt = KeyedJaggedTensor( |
| 484 | + keys=["f1", "f2", "f3", "f4"], |
| 485 | + values=torch.cat( |
| 486 | + [ |
| 487 | + torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"), |
| 488 | + torch.arange(30, 60, 3, dtype=torch.int64, device="cpu"), |
| 489 | + torch.arange(20, 30, 2, dtype=torch.int64, device="cpu"), |
| 490 | + torch.arange(0, 20, 2, dtype=torch.int64, device="cpu"), |
| 491 | + ] |
| 492 | + ), |
| 493 | + lengths=torch.cat( |
| 494 | + [ |
| 495 | + torch.tensor([4, 6], dtype=torch.int64, device="cpu"), |
| 496 | + torch.tensor([5, 5], dtype=torch.int64, device="cpu"), |
| 497 | + torch.tensor([1, 4], dtype=torch.int64, device="cpu"), |
| 498 | + torch.tensor([7, 3], dtype=torch.int64, device="cpu"), |
| 499 | + ] |
| 500 | + ), |
| 501 | + ) |
| 502 | + res = mc_ebc.forward(kjt) |
| 503 | + self.assertTrue(torch.equal(res.lengths(), kjt.lengths())) |
| 504 | + self.assertTrue( |
| 505 | + torch.equal( |
| 506 | + res.lengths(), torch.tensor([4, 6, 5, 5, 1, 4, 7, 3], dtype=torch.int64) |
| 507 | + ) |
| 508 | + ) |
0 commit comments