Motivation and description
There exists a method for batched_mul that reshapes arrays to allow for an arbitrary number of batch dimensions:
```julia
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
```

It would be useful to have the same support in `batched_transpose` and `batched_adjoint`.
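To pin down what the reshape-based method computes, here is a plain-loop reference version that multiplies matching matrix slices over every trailing batch index. It is a minimal sketch: the name `batched_mul_reference` is hypothetical (not part of the library), and it is meant as a readable oracle for tests, not as a fast implementation.

```julia
# Hypothetical reference implementation: for every batch index I over the
# trailing dimensions, z[:, :, I] = x[:, :, I] * y[:, :, I].
function batched_mul_reference(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    batch_size == size(y)[3:end] ||
        throw(DimensionMismatch("batch size has to be the same for the two arrays."))
    size(x, 2) == size(y, 1) ||
        throw(DimensionMismatch("inner matrix dimensions must agree"))
    # Allocate the output with the promoted element type and the same batch shape.
    z = similar(x, promote_type(T1, T2), size(x, 1), size(y, 2), batch_size...)
    for I in CartesianIndices(batch_size)
        # I is a CartesianIndex over the batch dimensions only.
        z[:, :, I] = x[:, :, I] * y[:, :, I]
    end
    return z
end
```

With 4-dimensional inputs of sizes `(2, 3, 2, 2)` and `(3, 3, 2, 2)`, the result has size `(2, 3, 2, 2)`, and each slice `z[:, :, i, j]` is an ordinary matrix product.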
Possible Implementation
The existing code is quite sophisticated and "lazy" (it wraps the parent array instead of copying it), so an eager version like this wouldn't fly:
```julia
batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))
```
I imagine it would be possible to generalize the code beyond three dimensions, though. The indexing methods are currently hard-coded for three dimensions, and things like the strides would also need to be generalized:
```julia
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end
```
Is it better to just use `PermutedDimsArray`?
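One way to explore that question is a lazy N-dimensional transpose built on Base's `PermutedDimsArray`, which swaps the first two dimensions and leaves all batch dimensions in place. This is a minimal sketch under stated assumptions: `batched_transpose_nd` is a hypothetical name, it does not hook into the existing `BatchedTranspose` type or its dispatch, and whether `batched_mul` would then take its fast path is not addressed here. For real element types the adjoint coincides with the transpose; a complex adjoint would additionally need a lazy `conj`, which is omitted.

```julia
# Hypothetical lazy N-d batched transpose: no data is copied, the result
# wraps A and permutes dims (1, 2) while keeping dims 3:N as batch dims.
function batched_transpose_nd(A::AbstractArray{T,N}) where {T,N}
    N >= 2 || throw(ArgumentError("need at least two dimensions"))
    perm = (2, 1, ntuple(i -> i + 2, N - 2)...)
    return PermutedDimsArray(A, perm)
end
```

Base already defines `strides` for `PermutedDimsArray` by permuting the parent's strides, so for a dense `2×3×4×5` array (parent strides `(1, 2, 6, 24)`) the wrapper reports `(2, 1, 6, 24)`, the same pattern as `(sp[2], sp[1], sp[3:end]...)`.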