Skip to content

Commit 8445321

Browse files
authored
Embedding out dimension (#276)
Added embedding output dimension to Item nets
1 parent 527c062 commit 8445321

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010
- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
11+
- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
1112

1213
## [0.12.0] - 24.02.2025
1314

rectools/models/nn/item_net.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def get_all_embeddings(self) -> torch.Tensor:
4646
"""Return item embeddings."""
4747
raise NotImplementedError()
4848

49+
@property
50+
def out_dim(self) -> int:
51+
"""Return item embedding output dimension."""
52+
raise NotImplementedError()
53+
4954
@property
5055
def device(self) -> torch.device:
5156
"""Return ItemNet device."""
@@ -222,6 +227,11 @@ def from_dataset_schema(
222227
)
223228
return None
224229

230+
@property
231+
def out_dim(self) -> int:
232+
"""Return categorical item embedding output dimension."""
233+
return self.embedding_bag.embedding_dim
234+
225235

226236
class IdEmbeddingsItemNet(ItemNetBase):
227237
"""
@@ -317,6 +327,11 @@ def from_dataset_schema(
317327
n_items = dataset_schema.items.n_hot
318328
return cls(n_factors, n_items, dropout_rate)
319329

330+
@property
331+
def out_dim(self) -> int:
332+
"""Return item embedding output dimension."""
333+
return self.ids_emb.embedding_dim
334+
320335

321336
class ItemNetConstructorBase(ItemNetBase):
322337
"""
@@ -467,3 +482,8 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
467482
item_emb = self.item_net_blocks[idx_block](items)
468483
item_embs.append(item_emb)
469484
return torch.sum(torch.stack(item_embs, dim=0), dim=0)
485+
486+
@property
487+
def out_dim(self) -> int:
488+
"""Return item net constructor output dimension."""
489+
return self.item_net_blocks[0].out_dim

tests/models/nn/test_item_net.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def test_embedding_shape_after_model_pass(self, n_items: int, n_factors: int) ->
6060
expected_item_ids = item_id_embeddings(items)
6161
assert expected_item_ids.shape == (n_items, n_factors)
6262

63+
@pytest.mark.parametrize("n_factors", ((2), (10)))
64+
def test_out_dim(self, n_factors: int) -> None:
65+
item_id_embeddings = IdEmbeddingsItemNet.from_dataset(DATASET, n_factors=n_factors, dropout_rate=0.5)
66+
out_dim = item_id_embeddings.out_dim
67+
assert out_dim == n_factors
68+
6369

6470
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
6571
class TestCatFeaturesItemNet:
@@ -295,6 +301,15 @@ def test_warns_when_dataset_schema_categorical_features_are_none(self) -> None:
295301
"""
296302
)
297303

304+
@pytest.mark.parametrize("n_factors", ((2), (10)))
305+
def test_out_dim(self, dataset_item_features: Dataset, n_factors: int) -> None:
306+
cat_item_embeddings = CatFeaturesItemNet.from_dataset(
307+
dataset_item_features, n_factors=n_factors, dropout_rate=0.5
308+
)
309+
assert isinstance(cat_item_embeddings, CatFeaturesItemNet)
310+
out_dim = cat_item_embeddings.out_dim
311+
assert out_dim == n_factors
312+
298313

299314
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
300315
class TestSumOfEmbeddingsConstructor:
@@ -485,3 +500,20 @@ def test_raise_when_no_item_net_blocks(
485500
SumOfEmbeddingsConstructor.from_dataset(
486501
ds, n_factors=10, dropout_rate=0.5, item_net_block_types=item_net_block_types
487502
)
503+
504+
@pytest.mark.parametrize(
505+
"item_net_block_types,n_factors",
506+
(
507+
((IdEmbeddingsItemNet,), 8),
508+
((IdEmbeddingsItemNet, CatFeaturesItemNet), 16),
509+
((CatFeaturesItemNet,), 16),
510+
),
511+
)
512+
def test_out_dim(
513+
self, dataset_item_features: Dataset, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]], n_factors: int
514+
) -> None:
515+
item_net = SumOfEmbeddingsConstructor.from_dataset(
516+
dataset_item_features, n_factors=n_factors, dropout_rate=0.5, item_net_block_types=item_net_block_types
517+
)
518+
out_dim = item_net.out_dim
519+
assert out_dim == n_factors

0 commit comments

Comments
 (0)