Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,56 @@

struct SectorFusion <: FusionStyle end

TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion()

Check warning on line 24 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L24

Added line #L24 was not covered by tests

# TODO consider heterogeneous sectors?
# TBD consider heterogeneous sectors?
TensorAlgebra.trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = trivial(first(t))

Check warning on line 27 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L27

Added line #L27 was not covered by tests

function matricize_axes(

Check warning on line 29 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L29

Added line #L29 was not covered by tests
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}
)
@assert !isempty(blocked_axes)
default_axis = trivial_axis(Tuple(blocked_axes))
codomain_axes, domain_axes = blocks(blocked_axes)
@assert !(isempty(codomain_axes) && isempty(domain_axes))
row_axis = unmerged_tensor_product(trivial_axes(domain_axes), codomain_axes...)
unflipped_col_axis = unmerged_tensor_product(trivial_axes(codomain_axes), domain_axes...)
row_axis = unmerged_tensor_product(default_axis, codomain_axes...)
unflipped_col_axis = unmerged_tensor_product(default_axis, domain_axes...)
return row_axis, flip(unflipped_col_axis)

Check warning on line 37 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L32-L37

Added lines #L32 - L37 were not covered by tests
end

function TensorAlgebra.matricize(

Check warning on line 40 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L40

Added line #L40 was not covered by tests
::SectorFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
a_perm = permutedims(a, Tuple(biperm))
row_axis, col_axis = matricize_axes(axes(a)[biperm])
a_reshaped = blockreshape(a_perm, (row_axis, col_axis))

Check warning on line 45 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L43-L45

Added lines #L43 - L45 were not covered by tests
# Sort the blocks by sector and merge the equivalent sectors.
return sectormergesort(a_reshaped)

Check warning on line 47 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L47

Added line #L47 was not covered by tests
end

function TensorAlgebra.unmatricize(

Check warning on line 50 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L50

Added line #L50 was not covered by tests
::SectorFusion,
m::AbstractMatrix,
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
)
# First, fuse axes to get `blockmergesortperm`.
# Then unpermute the blocks.
row_col_axes = row_and_column_axes(blocked_axes)
row_col_axes = matricize_axes(blocked_axes)

Check warning on line 57 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L57

Added line #L57 was not covered by tests

blockperms = blocksortperm.(row_col_axes)
sorted_axes = map((r, I) -> only(axes(r[I])), row_col_axes, blockperms)

Check warning on line 60 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L59-L60

Added lines #L59 - L60 were not covered by tests

# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
m_unblocked = m[sorted_axes...]
m_blockpermed = m_unblocked[invblockperm.(blockperms)...]
return unmatricize(FusionStyle(BlockSparseArray), m_blockpermed, blocked_axes)

Check warning on line 68 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end

# Sort the blocks by sector and then merge the common sectors.
function sectormergesort(a::AbstractArray)
I = blockmergesortperm.(axes(a))
return a[I...]

Check warning on line 74 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end
end
Loading