Skip to content

Commit dd6ddbc

Browse files
committed
Incremental progress
1 parent eae24bb commit dd6ddbc

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

src/pullbacks/lq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function lq_pullback!(
5959
ΔQ̃ = zero!(similar(Q, (p, n)))
6060
if !iszerotangent(ΔQ)
6161
ΔQ1 = view(ΔQ, 1:p, :)
62-
copy!(ΔQ̃, ΔQ1)
62+
ΔQ̃ .= ΔQ1
6363
if p < size(Q, 1)
6464
Q2 = view(Q, (p + 1):size(Q, 1), :)
6565
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)

src/pullbacks/qr.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function qr_pullback!(
5959

6060
ΔQ̃ = zero!(similar(Q, (m, p)))
6161
if !iszerotangent(ΔQ)
62-
copy!(ΔQ̃, view(ΔQ, :, 1:p))
62+
ΔQ̃ .= view(ΔQ, :, 1:p)
6363
if p < size(Q, 2)
6464
Q2 = view(Q, :, (p + 1):size(Q, 2))
6565
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
@@ -91,9 +91,9 @@ function qr_pullback!(
9191
M = zero!(similar(R, (p, p)))
9292
if !iszerotangent(ΔR)
9393
ΔR11 = view(ΔR, 1:p, 1:p)
94-
M = mul!(M, ΔR11, R11', 1, 1)
94+
M += ΔR11 * R11'
9595
end
96-
M = mul!(M, Q1', ΔQ̃, -1, 1)
96+
M -= Q1' * ΔQ̃
9797
view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
9898
if eltype(M) <: Complex
9999
Md = diagview(M)

test/mooncake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ for T in BLASFloats, n in (17, m, 23)
1616
TestSuite.seed_rng!(123)
1717
if CUDA.functional()
1818
TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
19-
#n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
19+
n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
2020
end
2121
#=if AMDGPU.functional()
2222
TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))

0 commit comments

Comments
 (0)