@@ -643,6 +643,32 @@ function concrete_getu(indp, syms::AbstractVector)
643
643
return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
644
644
end
645
645
646
+ """
647
+ $(TYPEDEF)
648
+
649
+ A callable struct which applies `p_constructor` to possibly nested arrays. It also
650
+ ensures that views (including nested ones) are concretized.
651
+ """
652
+ struct PConstructorApplicator{F}
653
+ p_constructor:: F
654
+ end
655
+
656
+ function (pca:: PConstructorApplicator )(x:: AbstractArray )
657
+ pca. p_constructor (x)
658
+ end
659
+
660
+ function (pca:: PConstructorApplicator{typeof(identity)} )(x:: SubArray )
661
+ collect (x)
662
+ end
663
+
664
+ function (pca:: PConstructorApplicator{typeof(identity)} )(x:: SubArray{<:AbstractArray} )
665
+ collect (pca .(x))
666
+ end
667
+
668
+ function (pca:: PConstructorApplicator )(x:: AbstractArray{<:AbstractArray} )
669
+ pca. p_constructor (pca .(x))
670
+ end
671
+
646
672
"""
647
673
$(TYPEDSIGNATURES)
648
674
@@ -657,6 +683,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
657
683
"""
658
684
function get_mtkparameters_reconstructor (srcsys:: AbstractSystem , dstsys:: AbstractSystem ;
659
685
initials = false , unwrap_initials = false , p_constructor = identity)
686
+ p_constructor = PConstructorApplicator (p_constructor)
660
687
# if we call `getu` on this (and it were able to handle empty tuples) we get the
661
688
# fields of `MTKParameters` except caches.
662
689
syms = reorder_parameters (
@@ -698,15 +725,16 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
698
725
else
699
726
ic = get_index_cache (dstsys)
700
727
blockarrsizes = Tuple (map (ic. discrete_buffer_sizes) do bufsizes
701
- map (x -> x. length, bufsizes)
728
+ p_constructor ( map (x -> x. length, bufsizes) )
702
729
end )
703
730
# discretes need to be blocked arrays
704
731
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
705
732
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
706
733
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
707
734
# tuple of `BlockedArray`s
708
735
Base. Fix2 (Broadcast. BroadcastFunction (BlockedArray), blockarrsizes) ∘
709
- Base. Fix1 (broadcast, p_constructor) ∘ getu (srcsys, syms[3 ])
736
+ Base. Fix1 (broadcast, p_constructor) ∘
737
+ getu (srcsys, syms[3 ])
710
738
end
711
739
rest_getters = map (Base. tail (Base. tail (Base. tail (syms)))) do buf
712
740
if buf == ()
@@ -1307,7 +1335,7 @@ function process_SciMLProblem(
1307
1335
if ! (pType <: AbstractArray )
1308
1336
pType = Array
1309
1337
end
1310
- p = MTKParameters (sys, op; floatT = floatT, container_type = pType, p_constructor)
1338
+ p = MTKParameters (sys, op; floatT = floatT, p_constructor)
1311
1339
else
1312
1340
p = p_constructor (better_varmap_to_vars (op, ps; tofloat, container_type = pType))
1313
1341
end
0 commit comments