diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 78cdd2e1f..0c40f1b31 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -746,6 +746,53 @@ def to_dense(self) -> List[torch.Tensor]:
             tensor_list.append(self.values()[offset:next_offset])
         return tensor_list
 
+    def to_dense_stacked(self) -> torch.Tensor:
+        """
+        Optimized JaggedTensor-to-dense conversion; faster than to_padded_dense() in common cases.
+
+        Performance optimizations:
+        1. Length-1 sequences: zero DtoH transfers (a simple reshape).
+        2. Uniform lengths: skips the max-length computation and its DtoH transfer.
+        3. Variable lengths: delegates directly to to_padded_dense() with no
+           reconstruction overhead.
+
+        Returns:
+            torch.Tensor: dense tensor equal to torch.vstack(jt.to_dense()) when
+            lengths are uniform; for variable lengths, a zero-padded tensor
+            identical to to_padded_dense().
+        """
+        lengths = self.lengths()
+        values = self.values()
+
+        # ==================== ULTRA-FAST PATH: length-1 sequences ====================
+        # The most common case in InTrainerSeqStore - every embedding has length 1.
+        if torch.all(lengths == 1):
+            # Zero DtoH transfers - a pure on-device reshape.
+            batch_size = lengths.size(0)
+            if values.dim() == 1:
+                return values.view(batch_size, 1)
+            else:
+                feature_dim = values.size(-1)
+                return values.view(batch_size, feature_dim)
+
+        # ==================== FAST PATH: uniform lengths (non-1) ====================
+        # All sequences share one length, so the expensive max-length computation
+        # can be skipped: one .item() call instead of torch.max().item().
+        first_length = lengths[0]
+        if torch.all(lengths == first_length):
+            seq_length = int(first_length.item())
+
+            # Call fbgemm with the exact length: no over-padding, so no trimming.
+            offsets = self.offsets()
+            return torch.ops.fbgemm.jagged_to_padded_dense(
+                values, [offsets], [seq_length], 0.0
+            )
+
+        # ==================== VARIABLE-LENGTH PATH ====================
+        # True variable lengths: delegate to to_padded_dense(). Performance there
+        # is equivalent; the fast paths above cover the common cases.
+        return self.to_padded_dense()
+
     def to_dense_weights(self) -> Optional[List[torch.Tensor]]:
         """
         Constructs a dense-representation of the JT's weights.
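For reviewers, a minimal usage sketch (hypothetical; it assumes the patch above is applied, while JaggedTensor, to_dense(), and torch.vstack are existing APIs): on the length-1 fast path the new method should return the same tensor as stacking to_dense(), with no device-to-host sync.

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

# All lengths are 1, so to_dense_stacked() takes the reshape-only path.
jt = JaggedTensor(
    values=torch.tensor([1.0, 2.0, 3.0]),
    lengths=torch.tensor([1, 1, 1], dtype=torch.int32),
)
stacked = jt.to_dense_stacked()          # shape (3, 1), no .item() calls
reference = torch.vstack(jt.to_dense())  # list-building reference path
assert torch.equal(stacked, reference)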
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index 09a7e6b5f..f13a693cd 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -396,6 +396,95 @@ def test_to_padded_dense(self) -> None:
         expected_t2 = torch.tensor(t2_value).type(torch.int64)
         self.assertTrue(torch.equal(t2, expected_t2))
 
+    def test_to_dense_stacked(self) -> None:
+        """Test to_dense_stacked() across its fast and fallback paths."""
+        # Test case 1: all sequences have length 1 (ultra-fast path).
+        values = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        lengths = torch.tensor([1, 1, 1, 1], dtype=torch.int32)
+        jt = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked = jt.to_dense_stacked()
+        expected = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+        self.assertTrue(torch.equal(dense_stacked, expected))
+
+        # Length-1 sequences with 2D values.
+        values_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+        lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
+        jt_2d = JaggedTensor(values=values_2d, lengths=lengths)
+
+        dense_stacked_2d = jt_2d.to_dense_stacked()
+        expected_2d = values_2d.view(3, 2)  # Shape is preserved.
+        self.assertTrue(torch.equal(dense_stacked_2d, expected_2d))
+
+        # Test case 2: all sequences share one uniform length (fast path).
+        values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+        lengths = torch.tensor([2, 2, 2], dtype=torch.int32)
+        jt = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked = jt.to_dense_stacked()
+        # Compare with to_padded_dense() to ensure consistency.
+        expected_padded = jt.to_padded_dense()
+        self.assertTrue(torch.equal(dense_stacked, expected_padded))
+
+        # Test case 3: variable lengths (fallback path).
+        values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        lengths = torch.tensor([2, 0, 3, 1, 2], dtype=torch.int32)
+        jt = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked = jt.to_dense_stacked()
+        # Must be equivalent to to_padded_dense() for variable lengths.
+        expected_padded = jt.to_padded_dense()
+        self.assertTrue(torch.equal(dense_stacked, expected_padded))
+
+        # Test case 4: empty tensor.
+        empty_values = torch.tensor([], dtype=torch.float32)
+        empty_lengths = torch.tensor([], dtype=torch.int32)
+        jt_empty = JaggedTensor(values=empty_values, lengths=empty_lengths)
+
+        dense_stacked_empty = jt_empty.to_dense_stacked()
+        # The empty case should produce an empty tensor.
+        self.assertEqual(dense_stacked_empty.numel(), 0)
+
+        # Test case 5: a single batch with multiple elements.
+        values = torch.tensor([10.0, 20.0, 30.0])
+        lengths = torch.tensor([3], dtype=torch.int32)
+        jt_single = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked_single = jt_single.to_dense_stacked()
+        expected_single = torch.tensor([[10.0, 20.0, 30.0]])
+        self.assertTrue(torch.equal(dense_stacked_single, expected_single))
+
+        # Test case 6: a mix of empty and non-empty sequences (variable length).
+        values = torch.tensor([1.0, 2.0, 3.0])
+        lengths = torch.tensor([2, 0, 1], dtype=torch.int32)
+        jt_mixed = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked_mixed = jt_mixed.to_dense_stacked()
+        expected_mixed = jt_mixed.to_padded_dense()
+        self.assertTrue(torch.equal(dense_stacked_mixed, expected_mixed))
+
+        # Test case 7: randomized shapes - to_dense_stacked() must match
+        # to_padded_dense(). Size the values from the lengths so the two always
+        # agree (adjusting the last length instead could make it negative).
+        lengths = torch.randint(1, 10, (20,), dtype=torch.int32)
+        values = torch.randn(int(lengths.sum().item()))
+
+        jt_large = JaggedTensor(values=values, lengths=lengths)
+
+        dense_stacked_large = jt_large.to_dense_stacked()
+        expected_padded_large = jt_large.to_padded_dense()
+        self.assertTrue(torch.equal(dense_stacked_large, expected_padded_large))
+
+        # Test case 8: non-float dtypes are preserved.
+        values_int = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.int64)
+        lengths_int = torch.tensor([2, 2, 2], dtype=torch.int32)
+        jt_int = JaggedTensor(values=values_int, lengths=lengths_int)
+
+        dense_stacked_int = jt_int.to_dense_stacked()
+        self.assertEqual(dense_stacked_int.dtype, torch.int64)
+        expected_int = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int64)
+        self.assertTrue(torch.equal(dense_stacked_int, expected_int))
+
     def test_to_padded_dense_weights(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).type(
             torch.float64
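As a companion to the tests above, a hypothetical standalone check of the variable-length fallback (again assuming the patch is applied; JaggedTensor and to_padded_dense() are existing torchrec APIs): with non-uniform lengths the method should return exactly what to_padded_dense() returns, i.e. rows zero-padded to the max length.

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

# Lengths differ, so to_dense_stacked() delegates to to_padded_dense().
jt = JaggedTensor(
    values=torch.tensor([1.0, 2.0, 3.0]),
    lengths=torch.tensor([2, 0, 1], dtype=torch.int32),
)
out = jt.to_dense_stacked()
# Expected: [[1., 2.], [0., 0.], [3., 0.]]
assert torch.equal(out, jt.to_padded_dense())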