Skip to content

Commit 224540d

Browse files
author
Simone Carlo Surace
committed
Fix type instability
1 parent 31b758a commit 224540d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,12 @@ end
445445
project_y = ProjectTo(y)
446446
function kron_pullback(z̄)
447447
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
448-
= @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 4)))))
449-
ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3)))))
448+
= @thunk(project_x(_dot_collect.(Ref(y), eachslice(dz; dims = (2, 4)))))
449+
ȳ = @thunk(project_y(_dot_collect.(Ref(x), eachslice(dz; dims = (1, 3)))))
450450
return NoTangent(), x̄, ȳ
451451
end
452452
return kron(x, y), kron_pullback
453453
end
454-
end
454+
455+
_dot_collect(A::AbstractMatrix, B::SubArray) = dot(A, B)
456+
_dot_collect(A::Diagonal, B::SubArray) = dot(A, collect(B))

0 commit comments

Comments
 (0)