Skip to content

Commit bbf0ea2

Browse files
committed
Adding MaybeAgg for use in places like MLPBlock
1 parent 68de7f9 commit bbf0ea2

File tree

2 files changed

+99
-3
lines changed

2 files changed

+99
-3
lines changed

merlin/models/torch/transforms/agg.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict
1+
from typing import Dict, Union
22

33
import torch
44
from torch import nn
@@ -130,3 +130,64 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
130130
raise RuntimeError("Input tensor shapes don't match for stacking.")
131131

132132
return torch.stack(sorted_tensors, dim=self.dim)
133+
134+
135+
class MaybeAgg(nn.Module):
136+
"""
137+
This class is designed to conditionally apply an aggregation operation
138+
(e.g., Stack or Concat) on a tensor or a dictionary of tensors.
139+
140+
Parameters
141+
----------
142+
agg : nn.Module
143+
The aggregation operation to be applied.
144+
145+
Examples
146+
--------
147+
>>> stack = Stack(dim=0)
148+
>>> maybe_agg = MaybeAgg(agg=stack)
149+
>>> tensor1 = torch.tensor([[1, 2], [3, 4]])
150+
>>> tensor2 = torch.tensor([[5, 6], [7, 8]])
151+
>>> input_dict = {"tensor1": tensor1, "tensor2": tensor2}
152+
>>> output = maybe_agg(input_dict)
153+
>>> print(output)
154+
tensor([[[1, 2],
155+
[3, 4]],
156+
157+
[[5, 6],
158+
[7, 8]]])
159+
160+
>>> tensor = torch.tensor([1, 2, 3])
161+
>>> output = maybe_agg(tensor)
162+
>>> print(output)
163+
tensor([1, 2, 3])
164+
"""
165+
166+
def __init__(self, agg: nn.Module):
167+
super().__init__()
168+
self.agg = agg
169+
170+
def forward(self, inputs: Union[Dict[str, torch.Tensor], torch.Tensor]) -> torch.Tensor:
171+
"""
172+
Conditionally applies the aggregation operation on the inputs.
173+
174+
Parameters
175+
----------
176+
inputs : Union[Dict[str, torch.Tensor], torch.Tensor]
177+
Inputs to be aggregated. If inputs is a dictionary of tensors,
178+
the aggregation operation will be applied. If inputs is a single tensor,
179+
it will be returned as is.
180+
181+
Returns
182+
-------
183+
torch.Tensor
184+
Aggregated tensor if inputs is a dictionary, otherwise the input tensor.
185+
"""
186+
187+
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
188+
return self.agg(inputs)
189+
190+
if not torch.jit.isinstance(inputs, torch.Tensor):
191+
raise RuntimeError("Inputs must be either a dictionary of tensors or a single tensor.")
192+
193+
return inputs

tests/unit/torch/transforms/test_agg.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44
from merlin.models.torch.block import Block
5-
from merlin.models.torch.transforms.agg import Concat, Stack
5+
from merlin.models.torch.transforms.agg import Concat, MaybeAgg, Stack
66
from merlin.models.torch.utils import module_utils
77

88

@@ -43,7 +43,7 @@ def test_from_registry(self):
4343
"a": torch.randn(2, 3),
4444
"b": torch.randn(2, 4),
4545
}
46-
output = block(input_tensors)
46+
output = module_utils.module_test(block, input_tensors)
4747
assert output.shape == (2, 7)
4848

4949

@@ -84,3 +84,38 @@ def test_from_registry(self):
8484
}
8585
output = block(input_tensors)
8686
assert output.shape == (2, 2, 3)
87+
88+
89+
class TestMaybeAgg:
90+
def test_with_single_tensor(self):
91+
tensor = torch.tensor([1, 2, 3])
92+
stack = Stack(dim=0)
93+
maybe_agg = MaybeAgg(agg=stack)
94+
95+
output = module_utils.module_test(maybe_agg, tensor)
96+
assert torch.equal(output, tensor)
97+
98+
def test_with_dict(self):
99+
stack = Stack(dim=0)
100+
maybe_agg = MaybeAgg(agg=stack)
101+
102+
tensor1 = torch.tensor([[1, 2], [3, 4]])
103+
tensor2 = torch.tensor([[5, 6], [7, 8]])
104+
input_dict = {"tensor1": tensor1, "tensor2": tensor2}
105+
expected_output = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
106+
output = module_utils.module_test(maybe_agg, input_dict)
107+
108+
assert torch.equal(output, expected_output)
109+
110+
def test_with_incompatible_dict(self):
111+
concat = Concat(dim=0)
112+
maybe_agg = MaybeAgg(agg=concat)
113+
114+
tensor1 = torch.tensor([1, 2, 3])
115+
tensor2 = torch.tensor([4, 5])
116+
input_dict = {"tensor1": (tensor1, tensor2)}
117+
118+
with pytest.raises(
119+
RuntimeError, match="Inputs must be either a dictionary of tensors or a single tensor"
120+
):
121+
maybe_agg(input_dict)

0 commit comments

Comments
 (0)