diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index b817f020a..63673be73 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -303,6 +303,7 @@ def compute_kernels( EmbeddingComputeKernel.FUSED.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, EmbeddingComputeKernel.FUSED_UVM.value, + EmbeddingComputeKernel.KEY_VALUE.value, ] def sharding_types(self, compute_device_type: str) -> List[str]: