Skip to content

Commit 3f94743

Browse files
malaybagfacebook-github-bot
authored andcommitted
Add check for extra values during unflattening (#3301)
Summary: Pull Request resolved: #3301 As discussed in the previous diff, adding input sanity check Reviewed By: iamzainhuda Differential Revision: D80642205 fbshipit-source-id: 8f790f1fcad54a6b0e5cb3432bf251470da6d24e
1 parent 39c11c1 commit 3f94743

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3067,6 +3067,10 @@ def _kjt_unflatten(
30673067
) -> KeyedJaggedTensor:
30683068
if len(values) < len(KeyedJaggedTensor._fields):
30693069
values.extend([None] * (len(KeyedJaggedTensor._fields) - len(values)))
3070+
elif len(values) > len(KeyedJaggedTensor._fields):
3071+
raise ValueError(
3072+
f"Too many values provided for KeyedJaggedTensor: {len(values)} vs {len(KeyedJaggedTensor._fields)}"
3073+
)
30703074
return KeyedJaggedTensor(
30713075
context,
30723076
*values[:-2],

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ def test_pytree_kjt(self) -> None:
172172
self.assertIsNone(kjt_0.inverse_indices_or_none())
173173
self.assertIsNone(kjt_1.inverse_indices_or_none())
174174

175+
kjt_0 = KeyedJaggedTensor(
176+
values=values,
177+
keys=keys,
178+
offsets=offsets,
179+
weights=weights,
180+
)
181+
elems, spec = pytree.tree_flatten(kjt_0)
182+
183+
# Simulate extra fields
184+
with self.assertRaises(ValueError):
185+
kjt_1 = pytree.tree_unflatten(elems + elems, spec)
186+
175187
def test_to_dict_vb(self) -> None:
176188
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
177189
weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])

0 commit comments

Comments
 (0)