Skip to content

Commit d814d37

Browse files
committed
add more array types
1 parent 1590e9f commit d814d37

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1111
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1212
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1313
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
14+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1516
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

src/code.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Code
22

3-
using StaticArrays, LabelledArrays, SparseArrays
3+
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra
44

55
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
66
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
@@ -375,7 +375,7 @@ function toexpr(a::MakeArray, st)
375375
end
376376

377377
## Array
378-
@inline function _create_array(::Union{Type{<:Array},Type{<:SubArray}}, T, ::Val{dims}, elems...) where dims
378+
@inline function _create_array(::Type{<:Array}, T, ::Val{dims}, elems...) where dims
379379
arr = Array{T}(undef, dims)
380380
@assert prod(dims) == nfields(elems)
381381
@inbounds for i=1:prod(dims)
@@ -384,15 +384,23 @@ end
384384
arr
385385
end
386386

387-
@inline function create_array(A::Union{Type{<:Array},Type{<:SubArray}}, T, d::Val, elems...)
387+
@inline function create_array(A::Type{<:Array}, T, d::Val, elems...)
388388
_create_array(A, T, d, elems...)
389389
end
390390

391-
@inline function create_array(A::Union{Type{<:Array},Type{<:SubArray}}, ::Nothing, d::Val{dims}, elems...) where dims
391+
@inline function create_array(A::Type{<:Array}, ::Nothing, d::Val{dims}, elems...) where dims
392392
T = promote_type(map(typeof, elems)...)
393393
_create_array(A, T, d, elems...)
394394
end
395395

396+
@inline function create_array(A::Type{<:SubArray{T,N,P,I,L}}, S, d::Val, elems...) where {T,N,P,I,L}
397+
create_array(P, S, d, elems...)
398+
end
399+
400+
@inline function create_array(A::Type{<:PermutedDimsArray{T,N,perm,iperm,P}}, S, d::Val, elems...) where {T,N,perm,iperm,P}
401+
create_array(P, S, d, elems...)
402+
end
403+
396404
## Matrix
397405

398406
@inline function create_array(::Type{<:Matrix}, ::Nothing, ::Val{dims}, elems...) where dims
@@ -403,6 +411,14 @@ end
403411
Base.typed_hvcat(T, dims, elems...)
404412
end
405413

414+
@inline function create_array(A::Type{<:Transpose{T,P}}, S, d::Val, elems...) where {T,P}
415+
create_array(P, S, d, elems...)
416+
end
417+
418+
@inline function create_array(A::Type{<:UpperTriangular{T,P}}, S, d::Val, elems...) where {T,P}
419+
create_array(P, S, d, elems...)
420+
end
421+
406422
## SArray
407423
@inline function create_array(::Type{<:SArray}, ::Nothing, ::Val{dims}, elems...) where dims
408424
SArray{Tuple{dims...}}(elems...)

test/code.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using SymbolicUtils.Code: LazyState
44
using StaticArrays
55
using LabelledArrays
66
using SparseArrays
7+
using LinearAlgebra
78

89
test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
910

@@ -92,6 +93,15 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
9293
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],
9394
MakeArray(view([a,b,a+b,a/b], :), arr)))) == [1, 2, 3, 1/2]
9495

96+
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],
97+
MakeArray(PermutedDimsArray([a b;a+b a/b], (1,2)), arr)))) == [1 2 ; 3 1/2]
98+
99+
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],
100+
MakeArray(transpose([a b;a+b a/b]), arr)))) == [1 3;2 1/2]
101+
102+
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],
103+
MakeArray(UpperTriangular([a b;a+b a/b]), arr)))) == [1 2;0 1/2]
104+
95105
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],
96106
MakeArray([a b;a+b a/b], arr)))) == [1 2; 3 1/2]
97107

0 commit comments

Comments
 (0)