Skip to content

Commit 61bf3cf

Browse files
Yixin Bao authored and meta-codesync[bot] committed
Add a flag to disable fallback in mpzch and return empty rows for cache missed ids. (#3325)
Summary: Pull Request resolved: #3325. X-link: #3325. Return empty embedding rows for missed ids in the ZCH module. Reviewed By: zlzhao1104. Differential Revision: D80683711. fbshipit-source-id: 63b875bf50ffc36457fda977befbd01ac0c46397
1 parent cd493f1 commit 61bf3cf

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

torchrec/modules/mc_modules.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,11 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
5454
return torch.cat([jt.values() for jt in jd.values()])
5555

5656

57+
@torch.fx.wrap
def _cat_jagged_lengths(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
    """Concatenate the ``lengths()`` tensor of every entry in ``jd``.

    The per-key tensors are joined in the dict's iteration order. The
    function is registered with ``torch.fx.wrap`` so that symbolic tracing
    records a call to it instead of tracing into its body.

    Args:
        jd: mapping whose values each expose a ``lengths()`` tensor.

    Returns:
        A single tensor holding all lengths back to back.
    """
    per_key_lengths = [jagged.lengths() for jagged in jd.values()]
    return torch.cat(per_key_lengths)
60+
61+
5762
# TODO: keep the old implementation for backward compatibility and will remove it later
5863
@torch.fx.wrap
5964
def _mcc_lazy_init(
@@ -416,10 +421,11 @@ def forward(
416421

417422
assert output is not None
418423
values: torch.Tensor = _cat_jagged_values(output)
424+
lengths: torch.Tensor = _cat_jagged_lengths(output)
419425
return KeyedJaggedTensor(
420426
keys=features.keys(),
421427
values=values,
422-
lengths=features.lengths(),
428+
lengths=lengths,
423429
weights=features.weights_or_none(),
424430
)
425431

torchrec/modules/tests/test_mc_modules.py

Lines changed: 80 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,23 @@
1111
from typing import Dict
1212

1313
import torch
14+
from torchrec.fb.modules.hash_mc_evictions import (
15+
HashZchEvictionConfig,
16+
HashZchEvictionPolicyName,
17+
)
18+
from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule
19+
from torchrec.modules.embedding_configs import EmbeddingConfig
1420
from torchrec.modules.mc_modules import (
1521
average_threshold_filter,
1622
DistanceLFU_EvictionPolicy,
1723
dynamic_threshold_filter,
1824
LFU_EvictionPolicy,
1925
LRU_EvictionPolicy,
26+
ManagedCollisionCollection,
2027
MCHManagedCollisionModule,
2128
probabilistic_threshold_filter,
2229
)
23-
from torchrec.sparse.jagged_tensor import JaggedTensor
30+
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
2431

2532

2633
class TestEvictionPolicy(unittest.TestCase):
@@ -427,3 +434,75 @@ def test_fx_jit_script_not_training(self) -> None:
427434
model.train(False)
428435
gm = torch.fx.symbolic_trace(model)
429436
torch.jit.script(gm)
437+
438+
def test_mc_module_forward(self) -> None:
    """Check that a ManagedCollisionCollection forward pass keeps the lengths.

    Two tables ("t1" and "t2"), each serving two features, are backed by
    HashZchManagedCollisionModule instances on which
    ``reset_inference_mode()`` has been called. The lengths of the returned
    KeyedJaggedTensor must equal the lengths of the input.
    """
    table_features = {"t1": ["f1", "f2"], "t2": ["f3", "f4"]}

    embedding_configs = [
        EmbeddingConfig(
            name=table_name,
            num_embeddings=100,
            embedding_dim=8,
            feature_names=feature_names,
        )
        for table_name, feature_names in table_features.items()
    ]

    # One independent module per table; both use identical settings.
    mc_modules = {
        table_name: HashZchManagedCollisionModule(
            zch_size=100,
            device=torch.device("cpu"),
            total_num_buckets=1,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )
        for table_name in table_features
    }
    for module in mc_modules.values():
        module.reset_inference_mode()

    mc_ebc = ManagedCollisionCollection(
        # Pyre-ignore [6]: In call `ManagedCollisionCollection.__init__`, for argument `managed_collision_modules`, expected `Dict[str, ManagedCollisionModule]` but got `Dict[str, HashZchManagedCollisionModule]`
        managed_collision_modules=mc_modules,
        embedding_configs=embedding_configs,
    )

    # (start, stop, step) of the id range supplied for f1, f2, f3 and f4.
    id_ranges = [(0, 20, 2), (30, 60, 3), (20, 30, 2), (0, 20, 2)]
    # Two length entries per feature, in the same feature order.
    feature_lengths = [[4, 6], [5, 5], [1, 4], [7, 3]]

    kjt = KeyedJaggedTensor(
        keys=["f1", "f2", "f3", "f4"],
        values=torch.cat(
            [
                torch.arange(start, stop, step, dtype=torch.int64, device="cpu")
                for start, stop, step in id_ranges
            ]
        ),
        lengths=torch.cat(
            [
                torch.tensor(pair, dtype=torch.int64, device="cpu")
                for pair in feature_lengths
            ]
        ),
    )

    res = mc_ebc.forward(kjt)

    # The output lengths must match the input, both by reference to the
    # input KJT and against the literal expected values.
    self.assertTrue(torch.equal(res.lengths(), kjt.lengths()))
    expected_lengths = torch.tensor([4, 6, 5, 5, 1, 4, 7, 3], dtype=torch.int64)
    self.assertTrue(torch.equal(res.lengths(), expected_lengths))

0 commit comments

Comments
 (0)