Skip to content

Commit 96cefaf

Browse files
Merge pull request #484 from mohamed82008/mt/lazy_rrule
Make the rrule's outer product lazy
2 parents c08f2e9 + 8d0fd26 commit 96cefaf

File tree

5 files changed

+12
-3
lines changed

5 files changed

+12
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.27.0"
4+
version = "3.0.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1515
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1616
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
17+
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1718
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1920
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
@@ -85,6 +86,7 @@ KLU = "0.6"
8586
KernelAbstractions = "0.9.16"
8687
Krylov = "0.9"
8788
KrylovKit = "0.6"
89+
LazyArrays = "1"
8890
Libdl = "1.10"
8991
LinearAlgebra = "1.10"
9092
MPI = "0.20"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
44

55
[compat]
66
Documenter = "1"
7-
LinearSolve = "1, 2"
7+
LinearSolve = "1, 2, 3"

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ PrecompileTools.@recompile_invalidations begin
1313
using LinearAlgebra
1414
using SparseArrays
1515
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
16+
using LazyArrays: @~, BroadcastArray
1617
using SciMLBase: AbstractLinearAlgorithm
1718
using SciMLOperators
1819
using SciMLOperators: AbstractSciMLOperator, IdentityOperator

src/adjoint.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
7676
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
7777
end
7878

79-
∂A = -λ * transpose(sol.u)
79+
tu = transpose(sol.u)
80+
∂A = BroadcastArray(@~ .-.* tu))
8081
∂b = λ
8182
∂prob = LinearProblem(∂A, ∂b, ∂∅)
8283

test/adjoint.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Zygote, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
33
using FiniteDiff
4+
using LazyArrays: BroadcastArray
45

56
n = 4
67
A = rand(n, n);
@@ -18,6 +19,7 @@ end
1819
f(A, b1) # Uses BLAS
1920

2021
dA, db1 = Zygote.gradient(f, A, b1)
22+
@test dA isa BroadcastArray
2123

2224
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
2325
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
@@ -34,6 +36,7 @@ _ff = (x, y) -> f(x,
3436
_ff(copy(A), copy(b1))
3537

3638
dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))
39+
@test dA isa BroadcastArray
3740

3841
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
3942
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
@@ -50,6 +53,7 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
5053
end
5154

5255
dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
56+
@test dA isa BroadcastArray
5357

5458
dA2 = FiniteDiff.finite_difference_gradient(
5559
x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
@@ -71,6 +75,7 @@ function f4(A, b1, b2; alg = LUFactorization())
7175
end
7276

7377
dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)
78+
@test dA isa BroadcastArray
7479

7580
dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
7681
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))

0 commit comments

Comments
 (0)