Open
Labels: enhancement (New feature or request)
Description
Replacing explicit tensor operations with `torch.einsum()` in the Zero-Order-Hold transformation can improve both performance and readability.
Replacing the original Zero-Order-Hold transformation at line 518 of `mamba_arch.py`:

```python
deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (b, l, d, n)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (b, l, d, n)
BX = deltaB * (x.unsqueeze(-1))                # (b, l, d, n)
```

with:

```python
deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))  # (b, l, d, n)
BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)         # (b, l, d, n)
```

can reduce execution time by up to ~40% while requiring the same number of FLOPs (see attached plot).
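In index notation, the original broadcast code computes the two quantities below. Here b, l, d and n are read as the batch, sequence-length, channel and state dimensions; that reading is inferred from the broadcasting pattern and is not stated in the issue. The einsum subscripts `'bld,dn->bldn'` and `'bld,bld,bln->bldn'` spell out exactly these index patterns. No index is summed over, so every output element is a plain product in both versions, which is consistent with the FLOP counts being equal:

```latex
\bar{A}_{bldn} = \exp\!\left(\Delta_{bld}\, A_{dn}\right),
\qquad
(\bar{B}x)_{bldn} = \Delta_{bld}\, x_{bld}\, B_{bln}
```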
Moreover, vectorizing the loop does not improve execution time any further.
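A minimal, self-contained sketch for checking that the two formulations agree and for timing them locally. The function names (`zoh_broadcast`, `zoh_einsum`, `bench`) and tensor sizes are hypothetical. The shapes `delta: (b, l, d)`, `A: (d, n)`, `B: (b, l, n)` and `x: (b, l, d)` are inferred from the broadcasting in the original snippet. This does not reproduce the attached plot, and any measured speedup will depend on device, tensor sizes and PyTorch version.

```python
import time

import torch


def zoh_broadcast(delta, A, B, x):
    """Original formulation: explicit unsqueeze + broadcasting."""
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (b, l, d, 1) * (d, n) -> (b, l, d, n)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (b, l, d, 1) * (b, l, 1, n) -> (b, l, d, n)
    BX = deltaB * x.unsqueeze(-1)                  # (b, l, d, n) * (b, l, d, 1) -> (b, l, d, n)
    return deltaA, BX


def zoh_einsum(delta, A, B, x):
    """Proposed formulation: the same element-wise products via torch.einsum."""
    deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
    BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)
    return deltaA, BX


def bench(fn, *args, iters=20):
    """Rough average wall-clock seconds per call."""
    use_cuda = args[0].is_cuda
    fn(*args)  # warm-up call
    if use_cuda:
        torch.cuda.synchronize()  # GPU kernels launch asynchronously
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
b, l, d, n = 2, 64, 32, 16  # hypothetical sizes
delta = torch.rand(b, l, d, device=device)
A = -torch.rand(d, n, device=device)  # assumption: negative A, so exp(delta * A) <= 1
B = torch.randn(b, l, n, device=device)
x = torch.randn(b, l, d, device=device)

# Both formulations should produce the same tensors up to float rounding.
for ref, out in zip(zoh_broadcast(delta, A, B, x), zoh_einsum(delta, A, B, x)):
    assert torch.allclose(ref, out, rtol=1e-5, atol=1e-6)

print(f"broadcast: {bench(zoh_broadcast, delta, A, B, x) * 1e6:.1f} us/call")
print(f"einsum:    {bench(zoh_einsum, delta, A, B, x) * 1e6:.1f} us/call")
```

For more reliable numbers than this simple loop, `torch.utils.benchmark.Timer` handles warm-up and CUDA synchronization itself.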
