Skip to content

Commit aacb2c4

Browse files
refactor: rename _is_array_shape to is_array_shape
1 parent 894efcf commit aacb2c4

File tree

4 files changed

+38
-38
lines changed

4 files changed

+38
-38
lines changed

src/code.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, a
1313
symtype, sorted_arguments, metadata, isterm, term, maketerm, unwrap_const,
1414
ArgsT, Const, SymVariant, _is_array_of_symbolics, _is_tuple_of_symbolics,
1515
ArrayOp, isarrayop, IdxToAxesT, ROArgsT, shape, Unknown, ShapeVecT,
16-
search_variables!, _is_index_variable, RangesT, IDXS_SYM, _is_array_shape
16+
search_variables!, _is_index_variable, RangesT, IDXS_SYM, is_array_shape
1717
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1818

1919
##== state management ==##
@@ -153,7 +153,7 @@ function function_to_expr(::Type{ArrayOp{T}}, O::BasicSymbolic{T}, st) where {T}
153153
output_eltype = get(st.rewrites, :arrayop_eltype, Float64)
154154
delete!(st.rewrites, :arrayop_eltype)
155155
sh = shape(O)
156-
default_output_buffer = if _is_array_shape(sh)
156+
default_output_buffer = if is_array_shape(sh)
157157
term(zeros, output_eltype, size(O))
158158
else
159159
term(zero, output_eltype)
@@ -226,7 +226,7 @@ function inplace_expr(x::BasicSymbolic{T}, outsym) where {T}
226226
new_expr = unidealize_indices(x.expr, ranges, new_ranges)
227227
loopvar_order = unique!(filter(x -> x isa BasicSymbolic{T}, vcat(reverse(x.output_idx), collect(keys(ranges)), collect(keys(new_ranges)))))
228228

229-
if _is_array_shape(sh)
229+
if is_array_shape(sh)
230230
inner_expr = SetArray(false, outsym, [AtIndex(term(CartesianIndex, x.output_idx...), term(x.reduce, term(getindex, outsym, x.output_idx...), new_expr))])
231231
else
232232
inner_expr = Assignment(outsym, term(x.reduce, outsym, new_expr))

src/methods.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ for f in diadic
223223
f === NaNMath.pow && continue
224224
@eval function promote_shape(::$(typeof(f)), sh1::ShapeT, sh2::ShapeT)
225225
@nospecialize sh1 sh2
226-
_is_array_shape(sh1) && _throw_array($f, sh1, sh2)
227-
_is_array_shape(sh2) && _throw_array($f, sh1, sh2)
226+
is_array_shape(sh1) && _throw_array($f, sh1, sh2)
227+
is_array_shape(sh2) && _throw_array($f, sh1, sh2)
228228
return ShapeVecT()
229229
end
230230
end
@@ -246,7 +246,7 @@ for f in monadic
246246
else
247247
@eval function promote_shape(::$(typeof(f)), sh::ShapeT)
248248
@nospecialize sh
249-
_is_array_shape(sh) && _throw_array($f, sh)
249+
is_array_shape(sh) && _throw_array($f, sh)
250250
return ShapeVecT()
251251
end
252252
end
@@ -255,7 +255,7 @@ end
255255
error_f_symbolic(f, T) = error("$f is not defined for $T.")
256256

257257
function promote_shape(::typeof(rem2pi), sha::ShapeT, shb::ShapeT)
258-
_is_array_shape(sha) && _throw_array(rem2pi, sha, shb)
258+
is_array_shape(sha) && _throw_array(rem2pi, sha, shb)
259259
ShapeVecT()
260260
end
261261
function Base.rem2pi(x::BasicSymbolic{T}, mode::Base.RoundingMode) where {T}
@@ -287,7 +287,7 @@ end
287287
function Base.inv(x::BasicSymbolic{T}) where {T}
288288
sh = shape(x)
289289
type = promote_symtype(inv, symtype(x))
290-
if _is_array_shape(sh)
290+
if is_array_shape(sh)
291291
return Term{T}(inv, ArgsT{T}((x,)); type = type, shape = sh)
292292
else
293293
return x ^ (-1)
@@ -381,7 +381,7 @@ function Base.adjoint(s::BasicSymbolic{T}) where {T}
381381
end
382382
sh = shape(s)
383383
stype = symtype(s)
384-
if _is_array_shape(sh)
384+
if is_array_shape(sh)
385385
type = promote_symtype(adjoint, stype)
386386
newsh = promote_shape(adjoint, sh)
387387
return Term{T}(adjoint, ArgsT{T}((s,)); type, shape = newsh)
@@ -493,7 +493,7 @@ function _ndims_from_shape(sh::ShapeT)
493493
end
494494
end
495495
Base.ndims(x::BasicSymbolic) = _ndims_from_shape(shape(x))
496-
Base.broadcastable(x::BasicSymbolic) = _is_array_shape(shape(x)) ? x : Ref(x)
496+
Base.broadcastable(x::BasicSymbolic) = is_array_shape(shape(x)) ? x : Ref(x)
497497
function Base.eachindex(x::BasicSymbolic)
498498
sh = shape(x)
499499
if sh isa Unknown
@@ -506,7 +506,7 @@ function Base.collect(x::BasicSymbolic)
506506
end
507507
function Base.iterate(x::BasicSymbolic)
508508
sh = shape(x)
509-
_is_array_shape(sh) || return x, nothing
509+
is_array_shape(sh) || return x, nothing
510510
idxs = eachindex(x)
511511
idx, state = iterate(idxs)
512512
return x[idx], (idxs, state)
@@ -530,12 +530,12 @@ promote_symtype(::Type{CartesianIndex}, xs...) = CartesianIndex{length(xs)}
530530
promote_symtype(::Type{CartesianIndex{N}}, xs::Vararg{T, N}) where {T, N} = CartesianIndex{N}
531531
function promote_shape(::Type{CartesianIndex}, xs::ShapeT...)
532532
@nospecialize xs
533-
@assert all(!_is_array_shape, xs)
533+
@assert all(!is_array_shape, xs)
534534
return ShapeVecT((1:length(xs),))
535535
end
536536
function promote_shape(::Type{CartesianIndex{N}}, xs::Vararg{ShapeT, N}) where {N}
537537
@nospecialize xs
538-
@assert all(!_is_array_shape, xs)
538+
@assert all(!is_array_shape, xs)
539539
return ShapeVecT((1:length(xs),))
540540
end
541541
function Base.CartesianIndex(x::BasicSymbolic{T}, xs::BasicSymbolic{T}...) where {T}
@@ -690,7 +690,7 @@ function _copy_broadcast!(buffer::BroadcastBuffer{T}, bc::Broadcast.Broadcasted{
690690

691691
for arg in canonical_args
692692
sh = shape(arg)
693-
is_arr = _is_array_shape(sh)
693+
is_arr = is_array_shape(sh)
694694
if !is_arr
695695
push!(args, arg)
696696
continue
@@ -757,7 +757,7 @@ promote_symtype(::typeof(LinearAlgebra.dot), ::Type{T}, ::Type{S}) where {eT, T
757757

758758
function LinearAlgebra.dot(x::BasicSymbolic{T}, y::BasicSymbolic{T}) where {T}
759759
shx = shape(x)
760-
if _is_array_shape(shx)
760+
if is_array_shape(shx)
761761
sh = promote_shape(LinearAlgebra.dot, shx, shape(y))
762762
type = promote_symtype(LinearAlgebra.dot, symtype(x), symtype(y))
763763
BSImpl.Term{T}(LinearAlgebra.dot, ArgsT{T}((x, y)); type, shape = sh)
@@ -994,7 +994,7 @@ function promote_symtype(::typeof(in), ::Type{T}, ::Type{S}) where {T, S}
994994
end
995995
function promote_shape(::typeof(in), sha::ShapeT, shb::ShapeT)
996996
@nospecialize sha shb
997-
@assert _is_array_shape(shb) || throw(ArgumentError("Symbolic `in` requires an array as the second argument."))
997+
@assert is_array_shape(shb) || throw(ArgumentError("Symbolic `in` requires an array as the second argument."))
998998
return ShapeVecT()
999999
end
10001000

@@ -1013,8 +1013,8 @@ function promote_symtype(::typeof(issubset), ::Type{T}, ::Type{S}) where {T <: A
10131013
end
10141014
function promote_shape(::typeof(issubset), sha::ShapeT, shb::ShapeT)
10151015
@nospecialize sha shb
1016-
@assert _is_array_shape(sha) || throw(ArgumentError("Symbolic `issubset` requires arrays as both arguments."))
1017-
@assert _is_array_shape(shb) || throw(ArgumentError("Symbolic `issubset` requires arrays as both arguments."))
1016+
@assert is_array_shape(sha) || throw(ArgumentError("Symbolic `issubset` requires arrays as both arguments."))
1017+
@assert is_array_shape(shb) || throw(ArgumentError("Symbolic `issubset` requires arrays as both arguments."))
10181018
return ShapeVecT()
10191019
end
10201020

@@ -1036,8 +1036,8 @@ for f in [union, intersect]
10361036
end
10371037
@eval function promote_shape(::$(typeof(f)), sha::ShapeT, shb::ShapeT)
10381038
@nospecialize sha shb
1039-
@assert _is_array_shape(sha) || throw(ArgumentError("Symbolic `$($f)` requires arrays as both arguments."))
1040-
@assert _is_array_shape(shb) || throw(ArgumentError("Symbolic `$($f)` requires arrays as both arguments."))
1039+
@assert is_array_shape(sha) || throw(ArgumentError("Symbolic `$($f)` requires arrays as both arguments."))
1040+
@assert is_array_shape(shb) || throw(ArgumentError("Symbolic `$($f)` requires arrays as both arguments."))
10411041
return Unknown(1)
10421042
end
10431043
for T1 in [AbstractArray, :(BasicSymbolic{T})], T2 in [AbstractArray, :(BasicSymbolic{T})]
@@ -1064,8 +1064,8 @@ function promote_symtype(::typeof(binomial), ::Type{T}, ::Type{S}) where {T <: N
10641064
end
10651065
function promote_shape(::typeof(binomial), sha::ShapeT, shb::ShapeT)
10661066
@nospecialize sha shb
1067-
_is_array_shape(sha) && _throw_array(sha)
1068-
_is_array_shape(shb) && _throw_array(shb)
1067+
is_array_shape(sha) && _throw_array(sha)
1068+
is_array_shape(shb) && _throw_array(shb)
10691069

10701070
return ShapeVecT()
10711071
end

src/substitute.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ function _default_scalarize(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, t
309309
@nospecialize f
310310

311311
sh = shape(x)
312-
_is_array_shape(sh) && return [x[idx] for idx in eachindex(x)]
312+
is_array_shape(sh) && return [x[idx] for idx in eachindex(x)]
313313

314314
args = arguments(x)
315315
if toplevel
@@ -324,7 +324,7 @@ function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T
324324
sh isa Unknown && return x
325325
@match x begin
326326
BSImpl.Const(;) => return x
327-
BSImpl.Sym(;) => _is_array_shape(sh) ? [x[idx] for idx in eachindex(x)] : x
327+
BSImpl.Sym(;) => is_array_shape(sh) ? [x[idx] for idx in eachindex(x)] : x
328328
BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce) => begin
329329
term === nothing || return scalarize(term, Val{toplevel}())
330330
subrules = Dict()

src/types.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,11 +2190,11 @@ end
21902190
"""))
21912191
end
21922192

2193-
_is_array_shape(sh::ShapeT) = sh isa Unknown || _ndims_from_shape(sh) > 0
2193+
is_array_shape(sh::ShapeT) = sh isa Unknown || _ndims_from_shape(sh) > 0
21942194
function _multiplied_shape(shapes)
2195-
first_arr = findfirst(_is_array_shape, shapes)
2195+
first_arr = findfirst(is_array_shape, shapes)
21962196
first_arr === nothing && return ShapeVecT(), first_arr
2197-
last_arr::Int = findlast(_is_array_shape, shapes)
2197+
last_arr::Int = findlast(is_array_shape, shapes)
21982198
first_arr == last_arr && return shapes[first_arr], first_arr
21992199

22002200
sh1::ShapeT = shapes[first_arr]
@@ -2218,7 +2218,7 @@ function _multiplied_shape(shapes)
22182218
for i in (first_arr + 1):last_arr
22192219
sh = shapes[i]
22202220
ndims_sh = _ndims_from_shape(sh)
2221-
_is_array_shape(sh) || continue
2221+
is_array_shape(sh) || continue
22222222
ndims_sh <= 2 || throw_expected_matvec(shend)
22232223
is_matmatmul || throw_incompatible_shapes(cur_shape, sh)
22242224
is_matmatmul = ndims_sh != 1
@@ -2266,9 +2266,9 @@ end
22662266

22672267
function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22682268
sh = shape(ex)
2269-
_is_array_shape(sh) || return ex, one_of_vartype(T)
2269+
is_array_shape(sh) || return ex, one_of_vartype(T)
22702270
@match ex begin
2271-
BSImpl.Term(; f, args, type) && if f === (*) && !_is_array_shape(shape(first(args))) end => begin
2271+
BSImpl.Term(; f, args, type) && if f === (*) && !is_array_shape(shape(first(args))) end => begin
22722272
if length(args) == 2
22732273
return args[1], args[2]
22742274
end
@@ -2399,7 +2399,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
23992399
end
24002400
sh = shape(term)
24012401
type = promote_symtype(*, type, symtype(term))
2402-
if _is_array_shape(sh)
2402+
if is_array_shape(sh)
24032403
coeff, arrterm = _split_arrterm_scalar_coeff(T, term)
24042404
_mul_worker!(T, num_coeff, den_coeff, num_dict, den_dict, coeff)
24052405
if iscall(arrterm) && operation(arrterm) === (*)
@@ -2730,7 +2730,7 @@ function _bslash_worker(::Type{T}, a, b) where {T}
27302730
sha = shape(a)
27312731
type = promote_symtype(\, symtype(a), symtype(b))
27322732
newshape = promote_shape(\, shape(a), shape(b))
2733-
if _is_array_shape(newshape) || _is_array_shape(sha)
2733+
if is_array_shape(newshape) || is_array_shape(sha)
27342734
# Scalar \ Anything == Anything / Scalar
27352735
return Term{T}(\, ArgsT{T}((a, b)); type, shape = newshape)
27362736
else
@@ -2828,7 +2828,7 @@ function ^(a::BasicSymbolic{T}, b) where {T <: Union{SymReal, SafeReal}}
28282828
newshape = promote_shape(^, sha, shb)
28292829
type = promote_symtype(^, symtype(a), symtype(b))
28302830

2831-
if _is_array_shape(sha)
2831+
if is_array_shape(sha)
28322832
@match a begin
28332833
BSImpl.Term(; f, args) && if f === (^) && isconst(args[1]) end => begin
28342834
base, exp = args
@@ -2843,7 +2843,7 @@ function ^(a::BasicSymbolic{T}, b) where {T <: Union{SymReal, SafeReal}}
28432843
end
28442844
_ => return Term{T}(^, ArgsT{T}((a, Const{T}(b))); type, shape = newshape)
28452845
end
2846-
elseif _is_array_shape(shb)
2846+
elseif is_array_shape(shb)
28472847
return Term{T}(^, ArgsT{T}((a, Const{T}(b))); type, shape = newshape)::BasicSymbolic{T}
28482848
end
28492849
if b isa Number
@@ -2907,7 +2907,7 @@ function ^(a::Union{Number, Matrix{<:Number}}, b::BasicSymbolic{T}) where {T}
29072907
isconst(b) && return Const{T}(a ^ unwrap_const(b))
29082908
newshape = promote_shape(^, shape(a), shape(b))
29092909
type = promote_symtype(^, symtype(a), symtype(b))
2910-
if _is_array_shape(newshape) && _isone(a)
2910+
if is_array_shape(newshape) && _isone(a)
29112911
if newshape isa Unknown
29122912
return Const{T}(LinearAlgebra.I)
29132913
else
@@ -2973,7 +2973,7 @@ function promote_shape(::typeof(getindex), sharr::ShapeT, shidxs::ShapeVecT...)
29732973
@nospecialize sharr
29742974
# `promote_symtype` rules out the presence of multidimensional indices - each index
29752975
# is either an integer, Colon or vector of integers.
2976-
_is_array_shape(sharr) || isempty(shidxs) || throw_not_array(sharr)
2976+
is_array_shape(sharr) || isempty(shidxs) || throw_not_array(sharr)
29772977
result = ShapeVecT()
29782978
for (i, idx) in enumerate(shidxs)
29792979
isempty(idx) && continue
@@ -3027,7 +3027,7 @@ Base.@propagate_inbounds function _getindex(arr::BasicSymbolic{T}, idxs::Union{B
30273027
idxs_i = 1
30283028
for oldidx in Iterators.drop(args, 1)
30293029
oldidx_sh = shape(oldidx)
3030-
if !_is_array_shape(oldidx_sh)
3030+
if !is_array_shape(oldidx_sh)
30313031
push!(newargs, oldidx)
30323032
continue
30333033
end
@@ -3111,7 +3111,7 @@ Base.@propagate_inbounds function _getindex(arr::BasicSymbolic{T}, idxs::Union{B
31113111
end
31123112
end
31133113
_ => begin
3114-
if _is_array_shape(newshape)
3114+
if is_array_shape(newshape)
31153115
new_output_idx = OutIdxT{T}()
31163116
expr_args = ArgsT{T}((arr,))
31173117
term_args = ArgsT{T}((arr,))

0 commit comments

Comments
 (0)