Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
30 changes: 28 additions & 2 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,38 @@
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
end

# BroadcastStyle is not called for two identical styles
# BroadcastStyle(::Style1, ::Style2) is not called when Style1 == Style2
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,))) = tuplemortar(((true,), (true,)))
# tuplemortar(((1,), (2,))) .== tuplemortar(((1, 2),)) = (true, true)
# tuplemortar(((1,), (2,))) .== tuplemortar(((1,), (2,), (3,))) = error DimensionMismatch

function Base.BroadcastStyle(
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
)
throw(DimensionMismatch("Incompatible blocks"))
return Base.Broadcast.Style{Tuple}()

Check warning on line 81 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L81

Added line #L81 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== (1, 2) = (true, true)
function Base.BroadcastStyle(

Check warning on line 85 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L85

Added line #L85 was not covered by tests
::AbstractBlockTupleBroadcastStyle, ::Base.Broadcast.Style{Tuple}
)
return Base.Broadcast.Style{Tuple}()

Check warning on line 88 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L88

Added line #L88 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== 1 = (true, false)
function Base.BroadcastStyle(

Check warning on line 92 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L92

Added line #L92 was not covered by tests
::Base.Broadcast.DefaultArrayStyle{0}, s::AbstractBlockTupleBroadcastStyle
)
return s

Check warning on line 95 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L95

Added line #L95 was not covered by tests
end

# tuplemortar(((1,), (2,))) .== [1, 1] = BlockVector([true, false], [1, 1])
function Base.BroadcastStyle(

Check warning on line 99 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L99

Added line #L99 was not covered by tests
a::Base.Broadcast.AbstractArrayStyle, ::AbstractBlockTupleBroadcastStyle
)
return a

Check warning on line 102 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L102

Added line #L102 was not covered by tests
end

function Base.copy(
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
Expand Down
29 changes: 27 additions & 2 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test: @test, @test_throws

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
using BlockArrays:
Block, BlockVector, blocklength, blocklengths, blockedrange, blockisequal, blocks
using TestExtras: @constinferred

using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
Expand Down Expand Up @@ -54,12 +55,36 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test (@constinferred bt .+ tuplemortar(((1,), (1, 1), (1, 1)))) ==
BlockedTuple{3,blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))
@test bt .+ tuplemortar(((1, 1), (1, 1), (1,))) isa NTuple{5,Int}
@test bt .+ tuplemortar(((1, 1), (1, 1), (1,))) == (2, 5, 3, 6, 4)

bt = tuplemortar(((1:2, 1:2), (1:3,)))
@test length.(bt) == tuplemortar(((2, 2), (3,)))
@test length.(length.(bt)) == tuplemortar(((1, 1), (1,)))

bt = tuplemortar(((1,), (2,)))
@test (bt .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== bt) == tuplemortar(((true,), (true,)))
@test (bt .== tuplemortar(((1, 2),))) isa Tuple{Bool,Bool}
@test (bt .== tuplemortar(((1, 2),))) == (true, true)
@test_throws DimensionMismatch bt .== tuplemortar(((1,), (2,), (3,)))
@test (bt .== (1, 2)) isa Tuple{Bool,Bool}
@test (bt .== (1, 2)) == (true, true)
@test_throws DimensionMismatch bt .== (1, 2, 3)
@test (bt .== 1) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (bt .== 1) == tuplemortar(((true,), (false,)))
@test (bt .== [1, 1]) isa BlockVector{Bool}
@test blocks(bt .== [1, 1]) == [[true], [false]]
@test_throws DimensionMismatch bt .== [1, 2, 3]

@test ((1, 2) .== bt) isa Tuple{Bool,Bool}
@test ((1, 2) .== bt) == (true, true)
@test_throws DimensionMismatch (1, 2, 3) .== bt
@test (1 .== bt) isa BlockedTuple{2,(1, 1),Tuple{Bool,Bool}}
@test (1 .== bt) == tuplemortar(((true,), (false,)))
@test ([1, 1] .== bt) isa BlockVector{Bool}
@test blocks([1, 1] .== bt) == [[true], [false]]

# empty blocks
bt = tuplemortar(((1,), (), (5, 3)))
@test bt isa BlockedTuple{3}
Expand Down
Loading