1+ using Adapt: adapt
12using ArrayLayouts: zero!
23using BlockArrays:
34 Block,
@@ -28,22 +29,18 @@ using BlockSparseArrays:
2829 blocktype,
2930 view!
3031using GPUArraysCore: @allowscalar
32+ using JLArrays: JLArray
3133using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
32- using NDTensors. GPUArraysCoreExtensions: cpu
3334using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
3435using TensorAlgebra: contract
3536using Test: @test , @test_broken , @test_throws , @testset , @inferred
3637include (" TestBlockSparseArraysUtils.jl" )
3738
38- using NDTensors: NDTensors
39- include (joinpath (pkgdir (NDTensors), " test" , " NDTensorsTestUtils" , " NDTensorsTestUtils.jl" ))
40- using . NDTensorsTestUtils: devices_list, is_supported_eltype
41- @testset " BlockSparseArrays (dev=$dev , eltype=$elt )" for dev in devices_list (copy (ARGS )),
39+ arrayts = (Array, JLArray)
40+ @testset " BlockSparseArrays (arraytype=$arrayt , eltype=$elt )" for arrayt in arrayts,
4241 elt in (Float32, Float64, Complex{Float32}, Complex{Float64})
4342
44- if ! is_supported_eltype (dev, elt)
45- continue
46- end
43+ dev (a) = adapt (arrayt, a)
4744 @testset " Broken" begin
4845 # TODO : Fix this and turn it into a proper test.
4946 a = dev (BlockSparseArray {elt} ([2 , 3 ], [2 , 3 ]))
@@ -268,7 +265,7 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
268265 @test storedlength (a) == 2 * 4 + 3 * 3
269266
270267 # TODO : Broken on GPU.
271- if dev ≠ cpu
268+ if arrayt ≠ Array
272269 a = dev (BlockSparseArray {elt} ([2 , 3 ], [3 , 4 ]))
273270 @test_broken a[Block (1 , 2 )] .= 2
274271 end
@@ -285,7 +282,7 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
285282 @test storedlength (a) == 2 * 4
286283
287284 # TODO : Broken on GPU.
288- if dev ≠ cpu
285+ if arrayt ≠ Array
289286 a = dev (BlockSparseArray {elt} ([2 , 3 ], [3 , 4 ]))
290287 @test_broken a[Block (1 , 2 )] .= 0
291288 end
@@ -321,15 +318,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
321318 @test iszero (b)
322319 @test iszero (storedlength (b))
323320 # TODO : Broken on GPU.
324- @test iszero (c) broken = dev ≠ cpu
321+ @test iszero (c) broken = arrayt ≠ Array
325322 @test iszero (storedlength (c))
326323 @allowscalar a[5 , 7 ] = 1
327324 @test ! iszero (a)
328325 @test storedlength (a) == 3 * 4
329326 @test ! iszero (b)
330327 @test storedlength (b) == 3 * 4
331328 # TODO : Broken on GPU.
332- @test ! iszero (c) broken = dev ≠ cpu
329+ @test ! iszero (c) broken = arrayt ≠ Array
333330 @test storedlength (c) == 3 * 4
334331 d = @view a[1 : 4 , 1 : 6 ]
335332 @test iszero (d)
0 commit comments