Skip to content

Commit 8e4857e

Browse files
committed
Remove LazyArrays.jl dependency
- Removed LazyArrays from Project.toml dependencies - Replaced BroadcastArray(@~ .-(λ .* tu)) with simple broadcast .-(λ .* tu) in adjoint.jl - Updated tests to check for AbstractMatrix instead of BroadcastArray - All tests pass successfully
1 parent 18d53c1 commit 8e4857e

File tree

5 files changed

+10
-9
lines changed

5 files changed

+10
-9
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1313
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1414
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
15-
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1615
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1716
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1817
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
@@ -105,7 +104,6 @@ Krylov = "0.10"
105104
KrylovKit = "0.10"
106105
KrylovPreconditioners = "0.3"
107106
LAPACK_jll = "3"
108-
LazyArrays = "2.3"
109107
Libdl = "1.10"
110108
LinearAlgebra = "1.10"
111109
MPI = "0.20"

src/LinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ using LinearAlgebra: LinearAlgebra, BlasInt, LU, Adjoint, BLAS, Bidiagonal, Bunc
1515
cholesky, cholesky!, diagind, dot, inv, ldiv!, ldlt!, lu, lu!, mul!,
1616
norm,
1717
qr, qr!, svd, svd!
18-
using LazyArrays: @~, BroadcastArray
1918
using SciMLBase: SciMLBase, LinearAliasSpecifier, AbstractSciMLOperator,
2019
init, solve!, reinit!, solve, ReturnCode, LinearProblem
2120
using SciMLOperators: SciMLOperators, AbstractSciMLOperator, IdentityOperator,

src/adjoint.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
8484
end
8585

8686
tu = adjoint(sol.u)
87-
∂A = BroadcastArray(@~ .-.* tu))
87+
∂A = .-.* tu)
8888
∂b = λ
8989
∂prob = LinearProblem(∂A, ∂b, ∂∅)
9090

test/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
3+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4+
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
5+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/adjoint.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using Zygote, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
33
using FiniteDiff, RecursiveFactorization
4-
using LazyArrays: BroadcastArray
54

65
n = 4
76
A = rand(n, n);
@@ -19,7 +18,7 @@ end
1918
f(A, b1) # Uses BLAS
2019

2120
dA, db1 = Zygote.gradient(f, A, b1)
22-
@test dA isa BroadcastArray
21+
@test dA isa AbstractMatrix
2322

2423
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
2524
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
@@ -37,7 +36,7 @@ _ff = (x,
3736
_ff(copy(A), copy(b1))
3837

3938
dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))
40-
@test dA isa BroadcastArray
39+
@test dA isa AbstractMatrix
4140

4241
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
4342
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
@@ -58,7 +57,7 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
5857
end
5958

6059
dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
61-
@test dA isa BroadcastArray
60+
@test dA isa AbstractMatrix
6261

6362
dA2 = FiniteDiff.finite_difference_gradient(
6463
x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
@@ -83,7 +82,7 @@ function f4(A, b1, b2; alg = LUFactorization())
8382
end
8483

8584
dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)
86-
@test dA isa BroadcastArray
85+
@test dA isa AbstractMatrix
8786

8887
dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
8988
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))

0 commit comments

Comments
 (0)