@@ -65,7 +65,7 @@ make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(
6565
6666# no `alg` argument
6767function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, :: Nothing , rdata)
68- dA_copy = make_mooncake_tangent (copy(ΔA))
68+ dA_copy = make_mooncake_fdata (copy(ΔA))
6969 A_copy = copy(A)
7070 dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
7171 copy_out, copy_pb!! = rrule(Mooncake. CoDual(f_c, Mooncake. NoFData()), Mooncake. CoDual(A_copy, dA_copy), Mooncake. CoDual(args, dargs_copy))
7575
7676# `alg` argument
7777function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
78- dA_copy = make_mooncake_tangent (copy(ΔA))
78+ dA_copy = make_mooncake_fdata (copy(ΔA))
7979 A_copy = copy(A)
8080 dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
8181 copy_out, copy_pb!! = rrule(Mooncake. CoDual(f_c, Mooncake. NoFData()), Mooncake. CoDual(A_copy, dA_copy), Mooncake. CoDual(args, dargs_copy), Mooncake. CoDual(alg, Mooncake. NoFData()))
@@ -84,7 +84,7 @@ function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
8484end
8585
8686function _get_inplace_derivative(f!, A, ΔA, args, Δargs, :: Nothing , rdata)
87- dA_inplace = make_mooncake_tangent (copy(ΔA))
87+ dA_inplace = make_mooncake_fdata (copy(ΔA))
8888 A_inplace = copy(A)
8989 dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
9090 # not every f! has a handwritten rrule!!
@@ -103,7 +103,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
103103end
104104
105105function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
106- dA_inplace = make_mooncake_tangent (copy(ΔA))
106+ dA_inplace = make_mooncake_fdata (copy(ΔA))
107107 A_inplace = copy(A)
108108 dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
109109 # not every f! has a handwritten rrule!!
@@ -143,9 +143,9 @@ function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Moo
143143 sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)}
144144 rvs_interp = Mooncake. get_interpreter(Mooncake. ReverseMode)
145145 rrule = Mooncake. build_rrule(rvs_interp, sig)
146- ΔA = randn!(similar(A))
146+ ΔA = A isa Diagonal ? Diagonal(randn!(similar(A . diag))) : randn!(similar(A))
147147
148- dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
148+ dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
149149 dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
150150
151151 dA_inplace_ = Mooncake. arrayify(A, dA_inplace)[2 ]
0 commit comments