Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,40 +111,50 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
groups_forward = column_groups(coloring_result)
groups_reverse = row_groups(coloring_result)

seeds_forward = [DI.multibasis(x, eachindex(x)[group]) for group in groups_forward]
seeds_reverse = [DI.multibasis(y, eachindex(y)[group]) for group in groups_reverse]

compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2)
compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1)

batched_seeds_forward = [
ntuple(b -> seeds_forward[1 + ((a - 1) * Bf + (b - 1)) % Nf], Val(Bf)) for a in 1:Af
]
batched_seeds_reverse = [
ntuple(b -> seeds_reverse[1 + ((a - 1) * Br + (b - 1)) % Nr], Val(Br)) for a in 1:Ar
]
# Handle forward direction
if isempty(groups_forward)
seeds_forward = typeof(DI.multibasis(x, Int[]))[]
compressed_matrix_forward = zeros(eltype(y), length(y), 0)
batched_seeds_forward = NTuple{Bf,typeof(DI.multibasis(x, Int[]))}[]
batched_results_forward = NTuple{Bf,typeof(similar(y))}[]
dummy_seeds_forward = ntuple(_ -> DI.multibasis(x, Int[]), Val(Bf))
else
seeds_forward = [DI.multibasis(x, eachindex(x)[group]) for group in groups_forward]
compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2)
batched_seeds_forward = [ntuple(b -> seeds_forward[1+((a-1)*Bf+(b-1))%Nf], Val(Bf)) for a in 1:Af]
batched_results_forward = [ntuple(b -> similar(y), Val(Bf)) for _ in batched_seeds_forward]
dummy_seeds_forward = batched_seeds_forward[1]
end

batched_results_forward = [
ntuple(b -> similar(y), Val(Bf)) for _ in batched_seeds_forward
]
batched_results_reverse = [
ntuple(b -> similar(x), Val(Br)) for _ in batched_seeds_reverse
]
# Handle reverse direction
if isempty(groups_reverse)
seeds_reverse = typeof(DI.multibasis(y, Int[]))[]
compressed_matrix_reverse = zeros(eltype(x), 0, length(x))
batched_seeds_reverse = NTuple{Br,typeof(DI.multibasis(y, Int[]))}[]
batched_results_reverse = NTuple{Br,typeof(similar(x))}[]
dummy_seeds_reverse = ntuple(_ -> DI.multibasis(y, Int[]), Val(Br))
else
seeds_reverse = [DI.multibasis(y, eachindex(y)[group]) for group in groups_reverse]
compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1)
batched_seeds_reverse = [ntuple(b -> seeds_reverse[1+((a-1)*Br+(b-1))%Nr], Val(Br)) for a in 1:Ar]
batched_results_reverse = [ntuple(b -> similar(x), Val(Br)) for _ in batched_seeds_reverse]
dummy_seeds_reverse = batched_seeds_reverse[1]
end

pushforward_prep = DI.prepare_pushforward_nokwarg(
strict,
f_or_f!y...,
DI.forward_backend(dense_backend),
x,
batched_seeds_forward[1],
dummy_seeds_forward,
contexts...;
)
pullback_prep = DI.prepare_pullback_nokwarg(
strict,
f_or_f!y...,
DI.reverse_backend(dense_backend),
x,
batched_seeds_reverse[1],
dummy_seeds_reverse,
contexts...;
)

Expand Down Expand Up @@ -195,20 +205,40 @@ function _sparse_jacobian_aux!(
Nf = batch_size_settings_forward.N
Nr = batch_size_settings_reverse.N

# Get dummy seeds based on whether batched seeds are empty
dummy_seeds_forward = if isempty(batched_seeds_forward)
ntuple(_ -> DI.multibasis(x, Int[]), Val(Bf))
else
batched_seeds_forward[1]
end

dummy_seeds_reverse = if isempty(batched_seeds_reverse)
# Only evaluate y when needed for reverse mode dummy seeds
y = if length(f_or_f!y) == 1
f_or_f!y[1](x, map(DI.unwrap, contexts)...)
else
f_or_f!y[1](f_or_f!y[2], x, map(DI.unwrap, contexts)...)
f_or_f!y[2]
end
ntuple(_ -> DI.multibasis(y, Int[]), Val(Br))
else
batched_seeds_reverse[1]
end

pushforward_prep_same = DI.prepare_pushforward_same_point(
f_or_f!y...,
pushforward_prep,
DI.forward_backend(dense_backend),
x,
batched_seeds_forward[1],
dummy_seeds_forward,
contexts...,
)
pullback_prep_same = DI.prepare_pullback_same_point(
f_or_f!y...,
pullback_prep,
DI.reverse_backend(dense_backend),
x,
batched_seeds_reverse[1],
dummy_seeds_reverse,
contexts...,
)

Expand Down
24 changes: 24 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/sparse.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
function mixedmode_empty_coloring_scenario()
sparsity_detector = TracerSparsityDetector()
function f!(y, x)
return y .= x
end
backend = AutoSparse(
MixedMode(AutoForwardDiff(), AutoMooncake());
sparsity_detector,
coloring_algorithm=GreedyColoringAlgorithm(; postprocessing=true),
)
N = 50
x = zeros(N)
y = zeros(N)
return Scenario{:jacobian,:in}(
f!,
y,
x;
prep_args=(; y=zero(y), x=zeros(N), contexts=()),
res1=zeros(N, N),
backend=backend,
name="mixedmode_empty_coloring",
)
end
## Vector to vector

diffsquare(x::AbstractVector)::AbstractVector = diff(x) .^ 2
Expand Down Expand Up @@ -397,6 +420,7 @@ function sparse_scenarios(;

final_scens = Scenario[]
append!(final_scens, scens)
push!(final_scens, mixedmode_empty_coloring_scenario())
include_constantified && append!(final_scens, constantify(scens))
include_cachified && append!(final_scens, cachify(scens; use_tuples))
include_constantorcachified && append!(final_scens, constantorcachify(scens))
Expand Down
2 changes: 1 addition & 1 deletion README.md
Loading