@@ -1080,9 +1080,9 @@ function findsizes(store::NamedTuple, call::CallInfo)
10801080 append! (out, store. assert)
10811081 empty! (store. assert)
10821082 if length (store. need) > 0
1083- sizes = sizeinfer (store, call)
1084- sz_list = map (axwrap, store. need)
1085- push! (out, :( local ($ (sz_list ... ),) = ($ (sizes ... ),) ) )
1083+ inferred = sizeinfer (store, call)
1084+ ax_list = map (axwrap, store. need)
1085+ push! (out, :( local ($ (ax_list ... ),) = ($ (inferred ... ),) ) )
10861086 end
10871087 append! (out, store. mustassert) # NB do this after calling sizeinfer()
10881088 unique! (out)
@@ -1107,22 +1107,24 @@ function sizeinfer(store::NamedTuple, call::CallInfo)
11071107 known = [ haskey (store. dict, j) for j in pair. first ]
11081108
11091109 if sum (.! known) == 1 # bingo! now work out its size:
1110- num = :( length ( $ ( pair. second)) )
1110+ num = takelength ( pair. second)
11111111
11121112 denfacts = [ store. dict[i] for i in pair. first[known] ]
11131113 if length (denfacts) > 1
1114- den = :( prod (length, ($ (denfacts... ),)) )
1114+ # den = :( prod(length, ($(denfacts...),)) )
1115+ longs = map (takelength, denfacts)
1116+ den = :( Base.:* ($ (longs... )) )
11151117 else
1116- den = :( length ( $ ( denfacts[1 ])) )
1118+ den = takelength ( denfacts[1 ])
11171119 end
1118- rat = :( Base. OneTo (div ( $ num, $ den) ) )
1120+ rat = :( Base. OneTo ($ num ÷ $ den) )
11191121
11201122 i = pair. first[.! known][1 ]
11211123 d = findfirst (isequal (i), store. need)
11221124 d != nothing && (sizes[d] = rat)
11231125
11241126 str = " expected integer multiples, when calculating range of $i from range of $(join (pair. first, " ⊗ " )) "
1125- push! (store. mustassert, :( rem ($ num, $ den)== 0 || throw (ArgumentError ($ str))) )
1127+ push! (store. mustassert, :( ($ num % $ den)== 0 || throw (ArgumentError ($ str))) )
11261128 end
11271129 end
11281130 end
@@ -1134,6 +1136,21 @@ function sizeinfer(store::NamedTuple, call::CallInfo)
11341136 return sizes
11351137end
11361138
1139+ """
1140+ takelength(OntTo(n)) -> n
1141+ takelength(axes(A,2)) -> size(A,2)
1142+ """
1143+ function takelength (ex)
1144+ if Meta. isexpr (ex, :call ) && ex. args[1 ] in (Base. OneTo, :(Base. OneTo))
1145+ ex. args[2 ]
1146+ elseif Meta. isexpr (ex, :call ) && ex. args[1 ] in (:axes , axes, :(Base. axes))
1147+ @assert length (ex. args) == 3
1148+ :(Base. size ($ (ex. args[2 : end ]. .. )))
1149+ else
1150+ :(Base. length ($ ex))
1151+ end
1152+ end
1153+
11371154"""
11381155 maybestaticsizes([:, :, i], (:,:,*)) -> (:,:,*)
11391156 maybestaticsizes([:3, :4, i], (:,:,*)) -> Size(3,4)
0 commit comments