Skip to content

Commit ce1e176

Browse files
authored
update arrays nested functions in build_function
1 parent d5abfa3 commit ce1e176

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/build_function.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
131131
end
132132
end
133133

134+
# Detect heterogeneous element types of "arrays of arrays of matrices/sparce matrices"
135+
function is_array_array_matrix(F)
136+
return isa(F, AbstractVector) && all(isa.(F, AbstractArray{<:AbstractMatrix}))
137+
end
138+
function is_array_array_sparse_matrix(F)
139+
return isa(F, AbstractVector) && all(isa.(F, AbstractArray{<:AbstractSparseMatrix}))
140+
end
141+
134142
function _build_function(target::JuliaTarget, rhss, args...;
135143
conv = simplified_expr, expression = Val{true},
136144
checkbounds = false, constructor=nothing,
@@ -171,9 +179,9 @@ function _build_function(target::JuliaTarget, rhss, args...;
171179
_rhss = rhss
172180
end
173181

174-
if eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of sparse matrices
182+
if is_array_array_sparse_matrix(rhss) # Array of arrays of sparse matrices
175183
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
176-
elseif eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
184+
elseif is_array_array_matrix(rhss) # Array of arrays of arrays
177185
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)], init=Expr[])
178186
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
179187
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])

0 commit comments

Comments
 (0)