Skip to content

Commit e4e7c8b

Browse files
refactor: merge fill_u0 and fill_p implementations
1 parent af35520 commit e4e7c8b

File tree

1 file changed

+16
-54
lines changed

1 file changed

+16
-54
lines changed

src/remake.jl

Lines changed: 16 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -560,62 +560,24 @@ function _updated_u0_p_internal(
560560
end
561561

562562
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
563-
vsyms = variable_symbols(prob)
564-
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
565-
sym_to_idx = anydict()
566-
idx_to_sym = anydict()
567-
idx_to_val = anydict()
568-
for (k, v) in u0
569-
v === nothing && continue
570-
idx = variable_index(prob, k)
571-
idx === nothing && continue
572-
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
573-
idx = (idx,)
574-
k = (k,)
575-
v = (v,)
576-
end
577-
for (kk, vv, ii) in zip(k, v, idx)
578-
sym_to_idx[kk] = ii
579-
kk = idx_to_vsym[ii]
580-
sym_to_idx[kk] = ii
581-
idx_to_sym[ii] = kk
582-
idx_to_val[ii] = vv
583-
end
584-
end
585-
for sym in vsyms
586-
haskey(sym_to_idx, sym) && continue
587-
idx = variable_index(prob, sym)
588-
haskey(idx_to_val, idx) && continue
589-
sym_to_idx[sym] = idx
590-
idx_to_sym[idx] = sym
591-
idx_to_val[idx] = if defs !== nothing &&
592-
(defval = varmap_get(defs, sym)) !== nothing &&
593-
(symbolic_type(defval) != NotSymbolic() || use_defaults)
594-
defval
595-
else
596-
getu(prob, sym)(prob)
597-
end
598-
end
599-
newvals = anydict()
600-
for (idx, val) in idx_to_val
601-
newvals[idx_to_sym[idx]] = val
602-
end
603-
for (k, v) in u0
604-
haskey(sym_to_idx, k) && continue
605-
newvals[k] = v
606-
end
607-
return newvals
563+
fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob),
564+
index_function = variable_index)
608565
end
609566

610567
function fill_p(prob, p; defs = nothing, use_defaults = false)
611-
psyms = parameter_symbols(prob)
612-
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
568+
fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob),
569+
index_function = parameter_index)
570+
end
571+
572+
function fill_vars(
573+
prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function)
574+
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in allsyms)
613575
sym_to_idx = anydict()
614576
idx_to_sym = anydict()
615577
idx_to_val = anydict()
616-
for (k, v) in p
578+
for (k, v) in varmap
617579
v === nothing && continue
618-
idx = parameter_index(prob, k)
580+
idx = index_function(prob, k)
619581
idx === nothing && continue
620582
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
621583
idx = (idx,)
@@ -624,15 +586,15 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
624586
end
625587
for (kk, vv, ii) in zip(k, v, idx)
626588
sym_to_idx[kk] = ii
627-
kk = idx_to_psym[ii]
589+
kk = idx_to_vsym[ii]
628590
sym_to_idx[kk] = ii
629591
idx_to_sym[ii] = kk
630592
idx_to_val[ii] = vv
631593
end
632594
end
633-
for sym in psyms
595+
for sym in allsyms
634596
haskey(sym_to_idx, sym) && continue
635-
idx = parameter_index(prob, sym)
597+
idx = index_function(prob, sym)
636598
haskey(idx_to_val, idx) && continue
637599
sym_to_idx[sym] = idx
638600
idx_to_sym[idx] = sym
@@ -641,14 +603,14 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
641603
(symbolic_type(defval) != NotSymbolic() || use_defaults)
642604
defval
643605
else
644-
getp(prob, sym)(prob)
606+
getsym(prob, sym)(prob)
645607
end
646608
end
647609
newvals = anydict()
648610
for (idx, val) in idx_to_val
649611
newvals[idx_to_sym[idx]] = val
650612
end
651-
for (k, v) in p
613+
for (k, v) in varmap
652614
haskey(sym_to_idx, k) && continue
653615
newvals[k] = v
654616
end

0 commit comments

Comments
 (0)