Skip to content

Commit c6368c9

Browse files
authored
Upgrade luxury sparse (#555)
* Fix tests * update Project.toml
1 parent 222d1cf commit c6368c9

File tree

12 files changed

+22
-17
lines changed

12 files changed

+22
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ CuYao = "CUDA"
2424
BitBasis = "0.8, 0.9"
2525
CUDA = "4, 5"
2626
LinearAlgebra = "1"
27-
LuxurySparse = "0.7"
27+
LuxurySparse = "0.8"
2828
Reexport = "1"
2929
YaoAPI = "0.4"
3030
YaoArrayRegister = "0.9"

lib/YaoArrayRegister/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Adapt = "3, 4"
2222
BitBasis = "0.8, 0.9"
2323
DocStringExtensions = "0.8, 0.9"
2424
LegibleLambdas = "0.3"
25-
LuxurySparse = "0.7"
25+
LuxurySparse = "0.8"
2626
MLStyle = "0.4"
2727
StaticArrays = "1"
2828
StatsBase = "0.33 - 0.34"

lib/YaoArrayRegister/src/YaoArrayRegister.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ using BitBasis
1212
using LinearAlgebra
1313
using LegibleLambdas
1414
using StatsBase, Random
15-
using LuxurySparse, StaticArrays
15+
using StaticArrays
16+
import LuxurySparse: fastkron, IMatrix, PermMatrix, Diagonal, SparseMatrixCSC
1617
using DocStringExtensions
1718

1819
export AbstractArrayReg,

lib/YaoArrayRegister/src/utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,11 @@ function _wrap_identity(
441441
nlevel
442442
) where {T<:AbstractMatrix}
443443
length(num_bit_list) == length(data_list) + 1 || throw(ArgumentError())
444-
= kron
445444
reduce(
446445
zip(data_list, num_bit_list[2:end]);
447446
init = IMatrix(nlevel ^ num_bit_list[1]),
448447
) do x, y
449-
x y[1] IMatrix(nlevel ^ y[2])
448+
fastkron(fastkron(x, y[1]), IMatrix(nlevel ^ y[2]))
450449
end
451450
end
452451

lib/YaoBlocks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ InteractiveUtils = "1"
3030
KrylovKit = "0.5, 0.6, 0.7, 0.8, 0.9"
3131
LegibleLambdas = "0.2, 0.3"
3232
LinearAlgebra = "1"
33-
LuxurySparse = "0.7"
33+
LuxurySparse = "0.8"
3434
MLStyle = "0.3, 0.4"
3535
Random = "1"
3636
SparseArrays = "1"

lib/YaoBlocks/src/YaoBlocks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ using YaoAPI
99
using LinearAlgebra
1010
using YaoArrayRegister
1111
using YaoArrayRegister: , matvec, diff, autostatic, rot_mat
12-
using BitBasis, LuxurySparse
12+
using BitBasis, LuxurySparse, SparseArrays
13+
using LuxurySparse: fastkron
1314
using StatsBase, TupleTools, InteractiveUtils
1415
using MLStyle: @match
1516
using LinearAlgebra: eigen!
16-
using SparseArrays, LuxurySparse
1717
using Random, CacheServers
1818
import KrylovKit: exponentiate
1919
using DocStringExtensions

lib/YaoBlocks/src/autodiff/outerproduct_and_projection.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ unsafe_projection!(y::Diagonal, op::OuterProduct) = (y.diag .+= op.left .* op.ri
8080
unsafe_projection!(y::Matrix, adjy, v) = y .+= adjy .* v
8181

8282
@inline function unsafe_projection!(y::AbstractSparseMatrix, m::AbstractMatrix)
83-
is, js, vs = findnz(y)
84-
for (k, (i, j)) in enumerate(zip(is, js))
83+
for (k, (i, j, _)) in enumerate(IterNz(y))
8584
@inbounds y.nzval[k] += m[i, j]
8685
end
8786
y

lib/YaoBlocks/src/composite/kron.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function mat(::Type{T}, k::KronBlock{D,M}) where {T,D,M}
172172
Iterators.reverse(zip(subblocks(k), num_bit_list)),
173173
init = IMatrix{T}(D^ntrail),
174174
) do x, y
175-
kron(x, mat(T, y[1]), IMatrix(D^y[2]))
175+
fastkron(fastkron(x, mat(T, y[1])), IMatrix(D^y[2]))
176176
end
177177
end
178178

lib/YaoBlocks/src/composite/tag/onlevels.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ end
2020
content(g::OnLevels) = g.gate
2121
function mat(::Type{T}, g::OnLevels{D, Ds}) where {T, D, Ds}
2222
m = mat(T, g.gate)
23-
I, J, V = LuxurySparse.findnz(m)
24-
return sparse(collect(g.levels[I]), collect(g.levels[J]), V, D, D)
23+
iter = IterNz(m)
24+
nnz = length(iter)
25+
is, js, vs = Vector{Int}(undef, nnz), Vector{Int}(undef, nnz), Vector{T}(undef, nnz)
26+
for (k, (i, j, v)) in enumerate(iter)
27+
is[k] = g.levels[i]
28+
js[k] = g.levels[j]
29+
vs[k] = v
30+
end
31+
return sparse(is, js, vs, D, D)
2532
end
2633
PropertyTrait(::OnLevels) = PreserveAll()
2734
Base.adjoint(x::OnLevels{D}) where D = OnLevels{D}(adjoint(x.gate), x.levels)

lib/YaoBlocks/src/outerproduct_and_projection.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ unsafe_projection!(y::Diagonal, op::OuterProduct) = (y.diag .+= op.left .* op.ri
8080
unsafe_projection!(y::Matrix, adjy, v) = y .+= adjy .* v
8181

8282
@inline function unsafe_projection!(y::AbstractSparseMatrix, m::AbstractMatrix)
83-
is, js, vs = findnz(y)
84-
for (k, (i, j)) in enumerate(zip(is, js))
83+
for (k, (i, j, _)) in enumerate(IterNz(y))
8584
@inbounds y.nzval[k] += m[i, j]
8685
end
8786
y

0 commit comments

Comments
 (0)