Skip to content

Commit 7f6e7fb

Browse files
malaybag and facebook-github-bot
authored and committed
Fix missing values in kjt unflatten (#3299)
Summary: Pull Request resolved: #3299. There is a code path where the stride_per_key_per_rank and inverse_indices values are not passed to the unflatten method (_kjt_unflatten). In that case, None is now assigned to those two fields. pytree flatten/unflatten was previously not tested for KJT (KeyedJaggedTensor), so this commit adds a test method that covers both flatten and unflatten of a KJT. Reviewed By: jd7-tr. Differential Revision: D80593739. fbshipit-source-id: e2d7f85715e4d49c4791266db6ee06651b43ab57
1 parent 4f9e75e commit 7f6e7fb

File tree

2 files changed

+54
-29
lines changed

2 files changed

+54
-29
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3065,6 +3065,8 @@ def _kjt_unflatten(
30653065
values: List[Optional[torch.Tensor]],
30663066
context: List[str], # context is _keys
30673067
) -> KeyedJaggedTensor:
3068+
if len(values) < len(KeyedJaggedTensor._fields):
3069+
values.extend([None] * (len(KeyedJaggedTensor._fields) - len(values)))
30683070
return KeyedJaggedTensor(
30693071
context,
30703072
*values[:-2],

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 52 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -112,42 +112,65 @@ def test_to_dict(self) -> None:
112112
torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
113113
)
114114

115-
def test_pytree(self) -> None:
116-
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
117-
j0 = JaggedTensor(
115+
def test_pytree_kjt(self) -> None:
116+
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
117+
weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5])
118+
keys = ["index_0", "index_1"]
119+
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
120+
stride_per_key_per_rank = [[2], [4]]
121+
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
122+
123+
kjt_0 = KeyedJaggedTensor(
118124
values=values,
119-
lengths=torch.IntTensor([1, 0, 2, 3]),
125+
keys=keys,
126+
offsets=offsets,
127+
weights=weights,
128+
stride_per_key_per_rank=stride_per_key_per_rank,
129+
inverse_indices=(keys, inverse_indices),
120130
)
121-
elems, spec = pytree.tree_flatten(j0)
122-
j1 = pytree.tree_unflatten(elems, spec)
131+
elems, spec = pytree.tree_flatten(kjt_0)
132+
kjt_1 = pytree.tree_unflatten(elems, spec)
123133

124-
self.assertTrue(torch.equal(j0.lengths(), j1.lengths()))
125-
self.assertIsNone(j0.weights_or_none())
126-
self.assertIsNone(j1.weights_or_none())
127-
self.assertTrue(torch.equal(j0.values(), j1.values()))
134+
self.assertTrue(torch.equal(kjt_0.values(), kjt_1.values()))
135+
self.assertIsNone(kjt_0.lengths_or_none())
136+
self.assertIsNone(kjt_1.lengths_or_none())
137+
self.assertTrue(torch.equal(kjt_0.weights(), kjt_1.weights()))
138+
self.assertTrue(torch.equal(kjt_0.offsets(), kjt_1.offsets()))
139+
self.assertEqual(kjt_0.keys(), kjt_1.keys())
140+
self.assertEqual(
141+
kjt_0.stride_per_key_per_rank(), kjt_1.stride_per_key_per_rank()
142+
)
143+
self.assertEqual(kjt_0.inverse_indices()[0], kjt_1.inverse_indices()[0])
144+
self.assertTrue(
145+
torch.equal(kjt_0.inverse_indices()[1], kjt_1.inverse_indices()[1])
146+
)
128147

129-
values = [
130-
torch.Tensor([1.0]),
131-
torch.Tensor(),
132-
torch.Tensor([7.0, 8.0]),
133-
torch.Tensor([10.0, 11.0, 12.0]),
134-
]
135-
weights = [
136-
torch.Tensor([1.0]),
137-
torch.Tensor(),
138-
torch.Tensor([7.0, 8.0]),
139-
torch.Tensor([10.0, 11.0, 12.0]),
140-
]
141-
j0 = JaggedTensor.from_dense(
148+
kjt_0 = KeyedJaggedTensor(
142149
values=values,
150+
keys=keys,
151+
offsets=offsets,
143152
weights=weights,
144153
)
145-
elems, spec = pytree.tree_flatten(j0)
146-
j1 = pytree.tree_unflatten(elems, spec)
147-
148-
self.assertTrue(torch.equal(j0.lengths(), j1.lengths()))
149-
self.assertTrue(torch.equal(j0.weights(), j1.weights()))
150-
self.assertTrue(torch.equal(j0.values(), j1.values()))
154+
elems, spec = pytree.tree_flatten(kjt_0)
155+
156+
# Simulate missing stride_per_key_per_rank and inverse_indices
157+
spec = pytree.TreeSpec(
158+
type=spec.type,
159+
context=spec.context,
160+
children_specs=spec.children_specs[:4],
161+
)
162+
kjt_1 = pytree.tree_unflatten(elems[:4], spec)
163+
164+
self.assertTrue(torch.equal(kjt_0.values(), kjt_1.values()))
165+
self.assertIsNone(kjt_0.lengths_or_none())
166+
self.assertIsNone(kjt_1.lengths_or_none())
167+
self.assertTrue(torch.equal(kjt_0.weights(), kjt_1.weights()))
168+
self.assertTrue(torch.equal(kjt_0.offsets(), kjt_1.offsets()))
169+
self.assertEqual(kjt_0.keys(), kjt_1.keys())
170+
self.assertTrue(len(kjt_0.stride_per_key_per_rank()) == 0)
171+
self.assertTrue(len(kjt_1.stride_per_key_per_rank()) == 0)
172+
self.assertIsNone(kjt_0.inverse_indices_or_none())
173+
self.assertIsNone(kjt_1.inverse_indices_or_none())
151174

152175
def test_to_dict_vb(self) -> None:
153176
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

0 commit comments

Comments
 (0)