@@ -112,42 +112,65 @@ def test_to_dict(self) -> None:
112
112
torch .equal (j1 .values (), torch .Tensor ([4.0 , 5.0 , 6.0 , 7.0 , 8.0 ]))
113
113
)
114
114
115
- def test_pytree (self ) -> None :
116
- values = torch .Tensor ([1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ])
117
- j0 = JaggedTensor (
115
+ def test_pytree_kjt (self ) -> None :
116
+ values = torch .Tensor ([1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ])
117
+ weights = torch .Tensor ([1.0 , 0.5 , 1.5 , 1.0 , 0.5 , 1.0 , 1.0 , 1.5 ])
118
+ keys = ["index_0" , "index_1" ]
119
+ offsets = torch .IntTensor ([0 , 2 , 2 , 3 , 4 , 5 , 8 ])
120
+ stride_per_key_per_rank = [[2 ], [4 ]]
121
+ inverse_indices = torch .tensor ([[0 , 1 , 0 ], [0 , 0 , 0 ]])
122
+
123
+ kjt_0 = KeyedJaggedTensor (
118
124
values = values ,
119
- lengths = torch .IntTensor ([1 , 0 , 2 , 3 ]),
125
+ keys = keys ,
126
+ offsets = offsets ,
127
+ weights = weights ,
128
+ stride_per_key_per_rank = stride_per_key_per_rank ,
129
+ inverse_indices = (keys , inverse_indices ),
120
130
)
121
- elems , spec = pytree .tree_flatten (j0 )
122
- j1 = pytree .tree_unflatten (elems , spec )
131
+ elems , spec = pytree .tree_flatten (kjt_0 )
132
+ kjt_1 = pytree .tree_unflatten (elems , spec )
123
133
124
- self .assertTrue (torch .equal (j0 .lengths (), j1 .lengths ()))
125
- self .assertIsNone (j0 .weights_or_none ())
126
- self .assertIsNone (j1 .weights_or_none ())
127
- self .assertTrue (torch .equal (j0 .values (), j1 .values ()))
134
+ self .assertTrue (torch .equal (kjt_0 .values (), kjt_1 .values ()))
135
+ self .assertIsNone (kjt_0 .lengths_or_none ())
136
+ self .assertIsNone (kjt_1 .lengths_or_none ())
137
+ self .assertTrue (torch .equal (kjt_0 .weights (), kjt_1 .weights ()))
138
+ self .assertTrue (torch .equal (kjt_0 .offsets (), kjt_1 .offsets ()))
139
+ self .assertEqual (kjt_0 .keys (), kjt_1 .keys ())
140
+ self .assertEqual (
141
+ kjt_0 .stride_per_key_per_rank (), kjt_1 .stride_per_key_per_rank ()
142
+ )
143
+ self .assertEqual (kjt_0 .inverse_indices ()[0 ], kjt_1 .inverse_indices ()[0 ])
144
+ self .assertTrue (
145
+ torch .equal (kjt_0 .inverse_indices ()[1 ], kjt_1 .inverse_indices ()[1 ])
146
+ )
128
147
129
- values = [
130
- torch .Tensor ([1.0 ]),
131
- torch .Tensor (),
132
- torch .Tensor ([7.0 , 8.0 ]),
133
- torch .Tensor ([10.0 , 11.0 , 12.0 ]),
134
- ]
135
- weights = [
136
- torch .Tensor ([1.0 ]),
137
- torch .Tensor (),
138
- torch .Tensor ([7.0 , 8.0 ]),
139
- torch .Tensor ([10.0 , 11.0 , 12.0 ]),
140
- ]
141
- j0 = JaggedTensor .from_dense (
148
+ kjt_0 = KeyedJaggedTensor (
142
149
values = values ,
150
+ keys = keys ,
151
+ offsets = offsets ,
143
152
weights = weights ,
144
153
)
145
- elems , spec = pytree .tree_flatten (j0 )
146
- j1 = pytree .tree_unflatten (elems , spec )
147
-
148
- self .assertTrue (torch .equal (j0 .lengths (), j1 .lengths ()))
149
- self .assertTrue (torch .equal (j0 .weights (), j1 .weights ()))
150
- self .assertTrue (torch .equal (j0 .values (), j1 .values ()))
154
+ elems , spec = pytree .tree_flatten (kjt_0 )
155
+
156
+ # Simulate missing stride_per_key_per_rank and inverse_indices
157
+ spec = pytree .TreeSpec (
158
+ type = spec .type ,
159
+ context = spec .context ,
160
+ children_specs = spec .children_specs [:4 ],
161
+ )
162
+ kjt_1 = pytree .tree_unflatten (elems [:4 ], spec )
163
+
164
+ self .assertTrue (torch .equal (kjt_0 .values (), kjt_1 .values ()))
165
+ self .assertIsNone (kjt_0 .lengths_or_none ())
166
+ self .assertIsNone (kjt_1 .lengths_or_none ())
167
+ self .assertTrue (torch .equal (kjt_0 .weights (), kjt_1 .weights ()))
168
+ self .assertTrue (torch .equal (kjt_0 .offsets (), kjt_1 .offsets ()))
169
+ self .assertEqual (kjt_0 .keys (), kjt_1 .keys ())
170
+ self .assertTrue (len (kjt_0 .stride_per_key_per_rank ()) == 0 )
171
+ self .assertTrue (len (kjt_1 .stride_per_key_per_rank ()) == 0 )
172
+ self .assertIsNone (kjt_0 .inverse_indices_or_none ())
173
+ self .assertIsNone (kjt_1 .inverse_indices_or_none ())
151
174
152
175
def test_to_dict_vb (self ) -> None :
153
176
values = torch .Tensor ([1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ])
0 commit comments