Skip to content

Commit 934937b

Browse files
feat: add flatten = false kwarg to reorder_parameters
1 parent 21325ca commit 934937b

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/systems/index_cache.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ function reorder_parameters(
502502
end
503503
end
504504

505-
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
505+
function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true)
506506
isempty(ps) && return ()
507507
param_buf = if ic.tunable_buffer_size.length == 0
508508
()
@@ -555,20 +555,37 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
555555
end
556556
end
557557

558-
result = broadcast.(
559-
unwrap, (
560-
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
558+
param_buf = broadcast.(unwrap, param_buf)
559+
initials_buf = broadcast.(unwrap, initials_buf)
560+
disc_buf = broadcast.(unwrap, disc_buf)
561+
const_buf = broadcast.(unwrap, const_buf)
562+
nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf)
563+
561564
if drop_missing
562-
result = map(result) do buf
563-
filter(buf) do sym
564-
return !isequal(sym, unwrap(variable(:DEF)))
565-
end
565+
filterer = !isequal(unwrap(variable(:DEF)))
566+
param_buf = filter.(filterer, param_buf)
567+
initials_buf = filter.(filterer, initials_buf)
568+
disc_buf = filter.(filterer, disc_buf)
569+
const_buf = filter.(filterer, const_buf)
570+
nonnumeric_buf = filter.(filterer, nonnumeric_buf)
571+
end
572+
573+
if flatten
574+
result = (
575+
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...)
576+
if all(isempty, result)
577+
return ()
566578
end
579+
return result
580+
else
581+
if isempty(param_buf)
582+
param_buf = ((),)
583+
end
584+
if isempty(initials_buf)
585+
initials_buf = ((),)
586+
end
587+
return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf)
567588
end
568-
if all(isempty, result)
569-
return ()
570-
end
571-
return result
572589
end
573590

574591
# Given a parameter index, find the index of the buffer it is in when

0 commit comments

Comments
 (0)