Skip to content

Commit 5523233

Browse files
committed
prettier range inference, avoid length(OneTo(...))
1 parent cf14dea commit 5523233

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

src/macro.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,9 +1080,9 @@ function findsizes(store::NamedTuple, call::CallInfo)
10801080
append!(out, store.assert)
10811081
empty!(store.assert)
10821082
if length(store.need) > 0
1083-
sizes = sizeinfer(store, call)
1084-
sz_list = map(axwrap, store.need)
1085-
push!(out, :( local ($(sz_list...),) = ($(sizes...),) ) )
1083+
inferred = sizeinfer(store, call)
1084+
ax_list = map(axwrap, store.need)
1085+
push!(out, :( local ($(ax_list...),) = ($(inferred...),) ) )
10861086
end
10871087
append!(out, store.mustassert) # NB do this after calling sizeinfer()
10881088
unique!(out)
@@ -1107,22 +1107,24 @@ function sizeinfer(store::NamedTuple, call::CallInfo)
11071107
known = [ haskey(store.dict, j) for j in pair.first ]
11081108

11091109
if sum(.!known) == 1 # bingo! now work out its size:
1110-
num = :(length($(pair.second)))
1110+
num = takelength(pair.second)
11111111

11121112
denfacts = [ store.dict[i] for i in pair.first[known] ]
11131113
if length(denfacts) > 1
1114-
den = :( prod(length, ($(denfacts...),)) )
1114+
# den = :( prod(length, ($(denfacts...),)) )
1115+
longs = map(takelength, denfacts)
1116+
den = :( Base.:*($(longs...)) )
11151117
else
1116-
den = :( length($(denfacts[1])) )
1118+
den = takelength(denfacts[1])
11171119
end
1118-
rat = :( Base.OneTo(div($num, $den)) )
1120+
rat = :( Base.OneTo($num ÷ $den) )
11191121

11201122
i = pair.first[.!known][1]
11211123
d = findfirst(isequal(i), store.need)
11221124
d != nothing && (sizes[d] = rat)
11231125

11241126
str = "expected integer multiples, when calculating range of $i from range of $(join(pair.first, ""))"
1125-
push!(store.mustassert, :( rem($num, $den)==0 || throw(ArgumentError($str))) )
1127+
push!(store.mustassert, :( ($num % $den)==0 || throw(ArgumentError($str))) )
11261128
end
11271129
end
11281130
end
@@ -1134,6 +1136,21 @@ function sizeinfer(store::NamedTuple, call::CallInfo)
11341136
return sizes
11351137
end
11361138

1139+
"""
1140+
takelength(OntTo(n)) -> n
1141+
takelength(axes(A,2)) -> size(A,2)
1142+
"""
1143+
function takelength(ex)
1144+
if Meta.isexpr(ex, :call) && ex.args[1] in (Base.OneTo, :(Base.OneTo))
1145+
ex.args[2]
1146+
elseif Meta.isexpr(ex, :call) && ex.args[1] in (:axes, axes, :(Base.axes))
1147+
@assert length(ex.args) == 3
1148+
:(Base.size($(ex.args[2:end]...)))
1149+
else
1150+
:(Base.length($ex))
1151+
end
1152+
end
1153+
11371154
"""
11381155
maybestaticsizes([:, :, i], (:,:,*)) -> (:,:,*)
11391156
maybestaticsizes([:3, :4, i], (:,:,*)) -> Size(3,4)

0 commit comments

Comments
 (0)