Skip to content

Commit 6da5d4f

Browse files
committed
Fix more tests
1 parent 0856230 commit 6da5d4f

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module BlockSparseArraysGradedUnitRangesExt
22

3-
using BlockSparseArrays: BlockSparseArray
3+
using BlockSparseArrays: AnyAbstractBlockSparseArray, BlockSparseArray, blocktype
44
using GradedUnitRanges: AbstractGradedUnitRange
5+
using TypeParameterAccessors: set_ndims, unwrap_array_type
56

67
# A block spare array similar to the input (dense) array.
78
# TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that,
@@ -13,7 +14,11 @@ function similar_blocksparse(
1314
)
1415
# TODO: Probably need to unwrap the type of `a` in certain cases
1516
# to make a proper block type.
16-
return BlockSparseArray{elt,length(axes),typeof(a)}(axes)
17+
return BlockSparseArray{
18+
elt,length(axes),set_ndims(unwrap_array_type(blocktype(a)), length(axes))
19+
}(
20+
axes
21+
)
1722
end
1823

1924
function Base.similar(
@@ -35,12 +40,43 @@ function Base.similar(
3540
return similar_blocksparse(a, elt, axes)
3641
end
3742

38-
function Base.getindex(
39-
a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange...
43+
# Fix ambiguity error with `BlockSparseArrays.jl`.
44+
function Base.similar(
45+
a::AnyAbstractBlockSparseArray,
46+
elt::Type,
47+
axes::Tuple{
48+
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
49+
},
4050
)
51+
return similar_blocksparse(a, elt, axes)
52+
end
53+
54+
function getindex_blocksparse(a::AbstractArray, I::AbstractUnitRange...)
4155
a′ = similar(a, only.(axes.(I))...)
4256
a′ .= a
4357
return a′
4458
end
4559

60+
function Base.getindex(
61+
a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange...
62+
)
63+
return getindex_blocksparse(a, I1, I_rest...)
64+
end
65+
66+
# Fix ambiguity errors.
67+
function Base.getindex(
68+
a::AnyAbstractBlockSparseArray,
69+
I1::AbstractGradedUnitRange,
70+
I_rest::AbstractGradedUnitRange...,
71+
)
72+
return getindex_blocksparse(a, I1, I_rest...)
73+
end
74+
function Base.getindex(
75+
a::AnyAbstractBlockSparseArray{<:Any,2},
76+
I1::AbstractGradedUnitRange,
77+
I2::AbstractGradedUnitRange,
78+
)
79+
return getindex_blocksparse(a, I1, I2)
80+
end
81+
4682
end
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using DerivableInterfaces: @interface, interface
22

3-
# TODO: Define with `@derive`.
4-
function Base.cat(as::AnyAbstractBlockSparseArray...; dims)
3+
function Base._cat(dims, as::AnyAbstractBlockSparseArray...)
4+
# TODO: Call `DerivableInterfaces.cat_along(dims, as...)` instead,
5+
# for better inferability. See:
6+
# https://github.com/ITensor/DerivableInterfaces.jl/pull/13
7+
# https://github.com/ITensor/DerivableInterfaces.jl/pull/17
58
return @interface interface(as...) cat(as...; dims)
69
end

src/abstractblocksparsearray/map.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ end
7171
)
7272
if isempty(a_srcs)
7373
# Broadcast expressions of the form `a .= 2`.
74-
error("Not implemented.")
74+
@interface interface fill!(a_dest, f())
75+
return a_dest
7576
end
7677
if iszero(ndims(a_dest))
7778
@interface interface map_zero_dim!(f, a_dest, a_srcs...)

0 commit comments

Comments
 (0)