Skip to content

Commit 51e7e91

Browse files
committed
Reenable more tests
1 parent 5e2fb53 commit 51e7e91

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

src/kroneckerarray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ end
114114
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
115115
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
116116

117-
Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b)
117+
# Eagerly collect arguments to make more general on GPU.
118+
Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b))
118119

119120
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
120121
return convert(Array{T,N}, collect(a))

test/test_blocksparsearrays.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,11 @@ end
8181
arrayts,
8282
elt in elts
8383

84-
if arrayt == JLArray
85-
# TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests
86-
# are broken, look into fixing that.
87-
continue
88-
end
89-
9084
dev = adapt(arrayt)
9185
r = @constinferred blockrange([2 × 2, 3 × 3])
9286
d = Dict(
93-
Block(1, 1) => Eye{elt}(2, 2) randn(elt, 2, 2),
94-
Block(2, 2) => Eye{elt}(3, 3) randn(elt, 3, 3),
87+
Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)),
88+
Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)),
9589
)
9690
a = @constinferred dev(blocksparse(d, r, r))
9791
@test sprint(show, a) == sprint(show, Array(a))
@@ -133,13 +127,21 @@ end
133127

134128
@test @constinferred(norm(a)) norm(Array(a))
135129

136-
b = @constinferred exp(a)
137-
@test Array(b) exp(Array(a))
130+
if arrayt === Array
131+
b = @constinferred exp(a)
132+
@test Array(b) exp(Array(a))
133+
else
134+
@test_broken exp(a)
135+
end
138136

139-
u, s, v = svd_compact(a)
140-
@test u * s * v a
141-
@test blocktype(u) === blocktype(a)
142-
@test blocktype(v) === blocktype(a)
137+
if arrayt === Array
138+
u, s, v = svd_compact(a)
139+
@test u * s * v a
140+
@test blocktype(u) === blocktype(a)
141+
@test blocktype(v) === blocktype(a)
142+
else
143+
@test_broken svd_compact(a)
144+
end
143145

144146
# Broken operations
145147
@test_broken inv(a)

0 commit comments

Comments
 (0)