Skip to content

Commit 3666df0

Browse files
refactor: use new array_literal instead of hvncat for arrays of symbolics
1 parent bc20620 commit 3666df0

File tree

7 files changed

+26
-24
lines changed

7 files changed

+26
-24
lines changed

docs/src/manual/variants.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ variant. This variant can be constructed using `Const{T}(val)` or `BSImpl.Const{
108108
where `T` is the appropriate `vartype`.
109109

110110
The `Const` constructors have an additional special behavior. If given an array of symbolics
111-
(or array of array of ... symbolics), it will return a `Term` (see below) with `hvncat` as
112-
the operation. This allows standard symbolic operations (such as [`substitute`](@ref)) to
111+
(or array of array of ... symbolics), it will return a `Term` (see below) with [`array_literal`](@ref)
112+
as the operation. This allows standard symbolic operations (such as [`substitute`](@ref)) to
113113
work on arrays of symbolics without excessive special-case handling and improved
114114
type-stability.
115115

@@ -409,6 +409,7 @@ unwrap_const
409409
### Inner constructors
410410

411411
```@docs
412+
SymbolicUtils.array_literal
412413
SymbolicUtils.BSImpl.Const
413414
SymbolicUtils.BSImpl.Sym
414415
SymbolicUtils.BSImpl.Term

src/methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ end
238238
_sequential_promote(T::TypeT) = T
239239

240240

241-
function promote_symtype(::typeof(hvncat), Tp::TypeT, Ts::TypeT...)
241+
function promote_symtype(::typeof(array_literal), Tp::TypeT, Ts::TypeT...)
242242
@assert Tp <: Tuple
243243
return Array{_sequential_promote(Ts...), length(Tp.parameters)}
244244
end

src/types.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,6 +1484,8 @@ function _is_tuple_of_symbolics(O::Tuple)
14841484
end
14851485
_is_tuple_of_symbolics(O) = false
14861486

1487+
array_literal(sz::NTuple{N, Int}, args...) where {N} = reshape(Base.vect(args...), sz)
1488+
14871489
"""
14881490
BSImpl.Const{T}(val) where {T}
14891491
@@ -1500,7 +1502,7 @@ arrays/tuples of symbolics to symbolic expressions.
15001502
This is the low-level constructor for constant expressions. It handles several special cases:
15011503
1. If `val` is already a `BasicSymbolic{T}`, returns it unchanged
15021504
2. If `val` is a `BasicSymbolic` of a different variant type, throws an error
1503-
3. If `val` is an array containing symbolic elements, creates a `Term` with `hvncat` operation
1505+
3. If `val` is an array containing symbolic elements, creates a `Term` with [`array_literal`](@ref) operation
15041506
4. If `val` is a tuple containing symbolic elements, creates a `Term` with `tuple` operation
15051507
5. Otherwise, creates a `Const` variant wrapping the value
15061508
@@ -1520,15 +1522,15 @@ The `unsafe` flag skips hash consing for performance in internal operations.
15201522
elseif val isa BasicSymbolic{TreeReal}
15211523
error("Cannot construct `BasicSymbolic{$T}` from `BasicSymbolic{TreeReal}`.")
15221524
elseif val isa AbstractArray && _is_array_of_symbolics(val)
1523-
args = ArgsT{T}((BSImpl.Const{T}(size(val); unsafe), BSImpl.Const{T}(false; unsafe)))
1524-
sizehint!(args, length(val) + 2)
1525+
args = ArgsT{T}((BSImpl.Const{T}(size(val); unsafe),))
1526+
sizehint!(args, length(val) + 1)
15251527
type = Union{}
15261528
for v in val
15271529
push!(args, BSImpl.Const{T}(v))
15281530
type = promote_type(type, symtype(v))
15291531
end
15301532
shape = ShapeVecT(axes(val))
1531-
return BSImpl.Term{T}(hvncat, args; type = Array{type, ndims(val)}, shape, unsafe)
1533+
return BSImpl.Term{T}(array_literal, args; type = Array{type, ndims(val)}, shape, unsafe)
15321534
elseif val isa Tuple && _is_tuple_of_symbolics(val)
15331535
args = ArgsT{T}()
15341536
sizehint!(args, length(val))
@@ -2470,7 +2472,7 @@ function TermInterface.maketerm(::Type{BasicSymbolic{T}}, f, args, metadata; @no
24702472
@set! res.metadata = metadata
24712473
end
24722474
return res::BasicSymbolic{T}
2473-
elseif f === hvncat
2475+
elseif f === array_literal
24742476
sh = ShapeVecT()
24752477
for dim in unwrap_const(args[1])
24762478
push!(sh, 1:dim)
@@ -3998,8 +4000,8 @@ function __stable_getindex(arr::BasicSymbolic{T}, sidxs::StableIndex) where {T}
39984000
sh::ShapeVecT = shape(arr)
39994001
@match arr begin
40004002
BSImpl.Const(; val) => return Const{T}(scalar_index(val, as_linear_idx(sh, sidxs)))
4001-
BSImpl.Term(; f, args) && if f === hvncat end => begin
4002-
return args[2 + as_linear_idx(sh, sidxs)]
4003+
BSImpl.Term(; f, args) && if f === array_literal end => begin
4004+
return args[1 + as_linear_idx(sh, sidxs)]
40034005
end
40044006
BSImpl.Term(; f, args) && if f isa TypeT && f <: CartesianIndex end => begin
40054007
return args[as_linear_idx(sh, sidxs)]
@@ -4069,8 +4071,8 @@ end
40694071
Base.@propagate_inbounds function _getindex(::Type{T}, arr::BasicSymbolic{T}, idxs::Union{BasicSymbolic{T}, Int, AbstractRange{Int}, Colon}...) where {T}
40704072
@match arr begin
40714073
BSImpl.Const(; val) && if all(x -> !(x isa BasicSymbolic{T}) || isconst(x), idxs) end => return Const{T}(val[unwrap_const.(idxs)...])
4072-
BSImpl.Term(; f) && if f === hvncat && all(x -> !(x isa BasicSymbolic{T}) || isconst(x), idxs) end => begin
4073-
return Const{T}(reshape(@view(arguments(arr)[3:end]), Tuple(size(arr)))[unwrap_const.(idxs)...])
4074+
BSImpl.Term(; f) && if f === array_literal && all(x -> !(x isa BasicSymbolic{T}) || isconst(x), idxs) end => begin
4075+
return Const{T}(reshape(@view(arguments(arr)[2:end]), Tuple(size(arr)))[unwrap_const.(idxs)...])
40744076
end
40754077
BSImpl.Term(; f, args) && if f isa TypeT && f <: CartesianIndex end => return args[idxs...]
40764078
BSImpl.Term(; f, args) && if f isa Operator && length(args) == 1 end => begin

test/basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ end
217217
var = Const{SymReal}(symvec)
218218
@test var isa BasicSymbolic{SymReal}
219219
@test isterm(var)
220-
@test isequal(arguments(var), Const{SymReal}.([(2,), false, h, x]))
220+
@test isequal(arguments(var), Const{SymReal}.([(2,), h, x]))
221221
@test symtype(var) == Vector{Number}
222222
@test shape(var) == ShapeVecT((1:2,))
223223
var = Const{SymReal}(symmat)
224224
@test var isa BasicSymbolic{SymReal}
225225
@test isterm(var)
226-
@test isequal(arguments(var), Const{SymReal}.([(2, 2), false, h, y, x, z]))
226+
@test isequal(arguments(var), Const{SymReal}.([(2, 2), h, y, x, z]))
227227
@test symtype(var) == Matrix{Number}
228228
@test shape(var) == ShapeVecT((1:2, 1:2))
229229
csymvec = Const{SymReal}(symvec)

test/cse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ end
6161
end
6262
ex = term(foo, [a^2 + b^2, b^2 + c], (a^2 + b^2, b^2 + c), c; type = Real)
6363
sorted_nodes = topological_sort(ex)
64-
@test length(sorted_nodes) == 10
65-
@test operation(sorted_nodes[8].rhs) === hvncat
66-
@test operation(sorted_nodes[9].rhs) === tuple
64+
@test length(sorted_nodes) == 9
65+
@test operation(sorted_nodes[7].rhs) === SymbolicUtils.array_literal
66+
@test operation(sorted_nodes[8].rhs) === tuple
6767
expr = quote
6868
a = 1
6969
b = 2

test/methods.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SymbolicUtils
2-
using SymbolicUtils: Sym, Term, symtype, BasicSymbolic, Const, ArgsT, promote_symtype, promote_shape, ShapeVecT, Unknown
2+
using SymbolicUtils: Sym, Term, symtype, BasicSymbolic, Const, ArgsT, promote_symtype, promote_shape, ShapeVecT, Unknown, array_literal
33
using Test
44
import NaNMath
55
import LinearAlgebra
@@ -72,9 +72,9 @@ end
7272
@test promote_shape(identity, ShapeVecT()) == ShapeVecT()
7373
@test promote_shape(identity, ShapeVecT((1:2, 1:3))) == ShapeVecT((1:2, 1:3))
7474
end
75-
@testset "promote_symtype for hvncat" begin
76-
@test promote_symtype(hvncat, NTuple{2, Int}, Int, Float64, Int32) == Array{Float64, 2}
77-
@test promote_symtype(hvncat, NTuple{3, Int}, Int, Int, Int) == Array{Int, 3}
75+
@testset "promote_symtype for `array_literal`" begin
76+
@test promote_symtype(array_literal, NTuple{2, Int}, Int, Float64, Int32) == Array{Float64, 2}
77+
@test promote_symtype(array_literal, NTuple{3, Int}, Int, Int, Int) == Array{Int, 3}
7878
end
7979

8080
@testset "promote_symtype for rem2pi" begin

test/misc.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,10 @@ end
227227
# Test array of symbolics
228228
arr = [x, y, x+y]
229229
const_arr = SymbolicUtils.Const{SymbolicUtils.SymReal}(arr)
230-
@test operation(const_arr) === hvncat
230+
@test operation(const_arr) === SymbolicUtils.array_literal
231231
args = arguments(const_arr)
232232
@test unwrap_const(args[1]) == (3,)
233-
@test unwrap_const(args[2]) == false
234-
@test length(args) == 5
233+
@test length(args) == 4
235234

236235
# Test tuple of symbolics
237236
tup = (x, y)

0 commit comments

Comments
 (0)