Skip to content

Commit 8762d2b

Browse files
authored
add functions for "arrays of arrays"
1 parent b274ab5 commit 8762d2b

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/build_function.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
131131
end
132132
end
133133

134+
# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
135+
function is_array_matrix(F)
136+
return isa(F, AbstractVector) && all(isa.(F, AbstractArray))
137+
end
138+
function is_array_sparse_matrix(F)
139+
return isa(F, AbstractVector) && all(isa.(F, AbstractSparseMatrix))
140+
end
134141
# Detect heterogeneous element types of "arrays of arrays of matrices/sparce matrices"
135142
function is_array_array_matrix(F)
136143
return isa(F, AbstractVector) && all(isa.(F, AbstractArray{<:AbstractMatrix}))
@@ -183,9 +190,9 @@ function _build_function(target::JuliaTarget, rhss, args...;
183190
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
184191
elseif is_array_array_matrix(rhss) # Array of arrays of arrays
185192
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)], init=Expr[])
186-
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
193+
elseif is_array_sparse_matrix(rhss) # Array of sparse matrices
187194
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
188-
elseif eltype(rhss) <: AbstractArray # Array of arrays
195+
elseif is_array_matrix(rhss) # Array of arrays
189196
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
190197
elseif rhss isa SparseMatrixCSC
191198
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]

0 commit comments

Comments
 (0)