Status: Open
Labels: bug
Description
In aloception-oss, we have overloaded some operations of torch.Tensor. For example, a mechanism allows torch.cat to concatenate multiple AugmentedTensors and their children recursively.
However, in the current state of the code, torch.cat works as expected when given a list of AugmentedTensors, but not when given a tuple: with a tuple, the children are not concatenated along the batch dimension.
```python
import torch

from aloscene import Frame
from aloscene.tensors import AugmentedTensor

x = Frame(torch.rand(3, 10, 10), names=("C", "H", "W"))
x.add_child("mychild", AugmentedTensor(torch.rand(2), names=("N",)), mergeable=True, align_dim=["B", "T"])

y = Frame(torch.rand(3, 10, 10), names=("C", "H", "W"))
y.add_child("mychild", AugmentedTensor(torch.rand(2), names=("N",)), mergeable=True, align_dim=["B", "T"])

# Concatenate a *tuple* of batched frames along the batch dimension
result = torch.cat((x.batch(), y.batch()), dim=0)
print(result.mychild.names, " - ", result.mychild.shape)
```

Expected output:

```
('B', 'N') - torch.Size([2, 2])
```

Current output:

```
('B', 'N') - torch.Size([1, 2])
```
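The symptom (the merged child keeps a batch size of 1 instead of 2) is consistent with a dispatch that only recognises a `list` as input to the recursive child merge. The following is a minimal, library-free sketch of that failure mode and of one possible fix. It does not use torch or aloscene; `Node`, `cat_list_only` and `cat_any_sequence` are hypothetical names, not part of aloscene's API, and the sketch assumes the merge is gated by an `isinstance(..., list)`-style check. The actual cause in aloscene may differ.

```python
from typing import Dict, List, Optional, Sequence


class Node:
    """Hypothetical stand-in for an AugmentedTensor: a list of rows along the
    batch dimension plus named, mergeable children (also Nodes)."""

    def __init__(self, rows: List[list], children: Optional[Dict[str, "Node"]] = None):
        self.rows = rows
        self.children = children or {}

    @property
    def shape(self):
        # (batch size, row length); assumes all rows have the same length
        return (len(self.rows), len(self.rows[0]) if self.rows else 0)


def cat_list_only(nodes: Sequence[Node]) -> Node:
    """Buggy dispatch (assumed): only a list triggers the recursive child merge."""
    rows = [r for n in nodes for r in n.rows]
    if isinstance(nodes, list):
        children = {
            name: cat_list_only([n.children[name] for n in nodes])
            for name in nodes[0].children
        }
    else:
        # A tuple falls through here: the children of the first input are
        # copied as-is, so the merged child keeps batch size 1.
        children = dict(nodes[0].children)
    return Node(rows, children)


def cat_any_sequence(nodes: Sequence[Node]) -> Node:
    """Fixed dispatch: any list or tuple is normalised to a list first."""
    nodes = list(nodes)
    rows = [r for n in nodes for r in n.rows]
    children = {
        name: cat_any_sequence([n.children[name] for n in nodes])
        for name in nodes[0].children
    }
    return Node(rows, children)


# Two "frames" with batch size 1, each carrying one child of length 2
x = Node([[0.0, 0.0, 0.0]], {"mychild": Node([[0.1, 0.2]])})
y = Node([[1.0, 1.0, 1.0]], {"mychild": Node([[0.3, 0.4]])})

print(cat_list_only([x, y]).children["mychild"].shape)     # (2, 2)
print(cat_list_only((x, y)).children["mychild"].shape)     # (1, 2)
print(cat_any_sequence((x, y)).children["mychild"].shape)  # (2, 2)
```

Normalising the input with `list(...)` (or checking `isinstance(x, (list, tuple))`) accepts any sequence, which matches `torch.cat`, whose first argument is documented as a sequence of tensors.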