@@ -234,6 +234,8 @@ This mostly aims to re-work the given expression into `some(steps(A))[i,j]`,
234234but also pushes `A = f(x)` into `store.top`.
235235"""
236236function standardise (ex, store:: NamedTuple , call:: CallInfo ; LHS= false )
237+ @nospecialize ex
238+
237239 # This acts only on single indexing expressions:
238240 if @capture (ex, A_{ijk__})
239241 static= true
@@ -378,6 +380,7 @@ target dims not correctly handled yet -- what do I want? TODO
378380Simple glue / stand. does not permutedims, but broadcasting may have to... avoid twice?
379381"""
380382function standardglue (ex, target, store:: NamedTuple , call:: CallInfo )
383+ @nospecialize ex
381384
382385 # The sole target here is indexing expressions:
383386 if @capture (ex, A_[inner__])
@@ -469,6 +472,7 @@ This beings the expression to have target indices,
469472by permutedims and if necessary broadcasting, always using `readycast()`.
470473"""
471474function targetcast (ex, target, store:: NamedTuple , call:: CallInfo )
475+ @nospecialize ex
472476
473477 # If just one naked expression, then we won't broadcast:
474478 if @capture (ex, A_[ijk__])
503507This is walked over the expression to prepare for `@__dot__` etc, by `targetcast()`.
504508"""
505509function readycast (ex, target, store:: NamedTuple , call:: CallInfo )
510+ @nospecialize ex
506511
507512 # Scalar functions can be protected entirely from broadcasting:
508513 # TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function!
@@ -578,6 +583,7 @@ If there are more than two factors, it recurses, and you get `(A*B) * C`,
578583or perhaps tuple `(A*B, C)`.
579584"""
580585function matmultarget (ex, target, parsed, store:: NamedTuple , call:: CallInfo )
586+ @nospecialize ex
581587
582588 @capture (ex, A_ * B_ * C__ | * (A_, B_, C__) ) || throw (MacroError (" can't @matmul that!" , call))
583589
@@ -631,6 +637,7 @@ pushing calculation steps into store.
631637Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
632638"""
633639function recursemacro (ex, store:: NamedTuple , call:: CallInfo )
640+ @nospecialize ex
634641
635642 # Actually look for recursion
636643 if @capture (ex, @reduce (subex__) )
@@ -675,6 +682,8 @@ This saves to `store` the sizes of all input tensors, and their sub-slices if an
675682 however it should not destroy this so that `sz_j` can be got later.
676683"""
677684function rightsizes (ex, store:: NamedTuple , call:: CallInfo )
685+ @nospecialize ex
686+
678687 :recurse in call. flags && return nothing # outer version took care of this
679688
680689 if @capture (ex, A_[outer__][inner__] | A_[outer__]{inner__} )
@@ -1115,8 +1124,7 @@ end
11151124
11161125tensorprimetidy (v:: Vector ) = Any[ tensorprimetidy (x) for x in v ]
11171126function tensorprimetidy (ex)
1118- MacroTools. postwalk (ex) do x
1119-
1127+ MacroTools. postwalk (ex) do @nospecialize x
11201128 @capture (x, ((ij__,) \ k_) ) && return :( ($ (ij... ),$ k) )
11211129 @capture (x, i_ \ j_ ) && return :( ($ i,$ j) )
11221130
@@ -1172,7 +1180,7 @@ containsindexing(s) = false
11721180function containsindexing (ex:: Expr )
11731181 flag = false
11741182 # MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex)
1175- MacroTools. postwalk (ex) do x
1183+ MacroTools. postwalk (ex) do @nospecialize x
11761184 # @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true)
11771185 if @capture (x, A_[ijk__])
11781186 # @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j])
@@ -1185,7 +1193,7 @@ end
11851193listindices (s:: Symbol ) = []
11861194function listindices (ex:: Expr )
11871195 list = []
1188- MacroTools. postwalk (ex) do x
1196+ MacroTools. postwalk (ex) do @nospecialize x
11891197 if @capture (x, A_[ijk__])
11901198 flat, _ = indexparse (nothing , ijk)
11911199 push! (list, flat)
0 commit comments