Skip to content

Commit c9c1fb1

Browse files
author
Michael Abbott
committed
add some nospecialize annotations
1 parent 3f8a0e7 commit c9c1fb1

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/macro.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ This mostly aims to re-work the given expression into `some(steps(A))[i,j]`,
234234
but also pushes `A = f(x)` into `store.top`.
235235
"""
236236
function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false)
237+
@nospecialize ex
238+
237239
# This acts only on single indexing expressions:
238240
if @capture(ex, A_{ijk__})
239241
static=true
@@ -378,6 +380,7 @@ target dims not correctly handled yet -- what do I want? TODO
378380
Simple glue / stand. does not permutedims, but broadcasting may have to... avoid twice?
379381
"""
380382
function standardglue(ex, target, store::NamedTuple, call::CallInfo)
383+
@nospecialize ex
381384

382385
# The sole target here is indexing expressions:
383386
if @capture(ex, A_[inner__])
@@ -469,6 +472,7 @@ This beings the expression to have target indices,
469472
by permutedims and if necessary broadcasting, always using `readycast()`.
470473
"""
471474
function targetcast(ex, target, store::NamedTuple, call::CallInfo)
475+
@nospecialize ex
472476

473477
# If just one naked expression, then we won't broadcast:
474478
if @capture(ex, A_[ijk__])
@@ -503,6 +507,7 @@ end
503507
This is walked over the expression to prepare for `@__dot__` etc, by `targetcast()`.
504508
"""
505509
function readycast(ex, target, store::NamedTuple, call::CallInfo)
510+
@nospecialize ex
506511

507512
# Scalar functions can be protected entirely from broadcasting:
508513
# TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function!
@@ -578,6 +583,7 @@ If there are more than two factors, it recurses, and you get `(A*B) * C`,
578583
or perhaps tuple `(A*B, C)`.
579584
"""
580585
function matmultarget(ex, target, parsed, store::NamedTuple, call::CallInfo)
586+
@nospecialize ex
581587

582588
@capture(ex, A_ * B_ * C__ | *(A_, B_, C__) ) || throw(MacroError("can't @matmul that!", call))
583589

@@ -631,6 +637,7 @@ pushing calculation steps into store.
631637
Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
632638
"""
633639
function recursemacro(ex, store::NamedTuple, call::CallInfo)
640+
@nospecialize ex
634641

635642
# Actually look for recursion
636643
if @capture(ex, @reduce(subex__) )
@@ -675,6 +682,8 @@ This saves to `store` the sizes of all input tensors, and their sub-slices if an
675682
however it should not destroy this so that `sz_j` can be got later.
676683
"""
677684
function rightsizes(ex, store::NamedTuple, call::CallInfo)
685+
@nospecialize ex
686+
678687
:recurse in call.flags && return nothing # outer version took care of this
679688

680689
if @capture(ex, A_[outer__][inner__] | A_[outer__]{inner__} )
@@ -1115,8 +1124,7 @@ end
11151124

11161125
tensorprimetidy(v::Vector) = Any[ tensorprimetidy(x) for x in v ]
11171126
function tensorprimetidy(ex)
1118-
MacroTools.postwalk(ex) do x
1119-
1127+
MacroTools.postwalk(ex) do @nospecialize x
11201128
@capture(x, ((ij__,) \ k_) ) && return :( ($(ij...),$k) )
11211129
@capture(x, i_ \ j_ ) && return :( ($i,$j) )
11221130

@@ -1172,7 +1180,7 @@ containsindexing(s) = false
11721180
function containsindexing(ex::Expr)
11731181
flag = false
11741182
# MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex)
1175-
MacroTools.postwalk(ex) do x
1183+
MacroTools.postwalk(ex) do @nospecialize x
11761184
# @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true)
11771185
if @capture(x, A_[ijk__])
11781186
# @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j])
@@ -1185,7 +1193,7 @@ end
11851193
listindices(s::Symbol) = []
11861194
function listindices(ex::Expr)
11871195
list = []
1188-
MacroTools.postwalk(ex) do x
1196+
MacroTools.postwalk(ex) do @nospecialize x
11891197
if @capture(x, A_[ijk__])
11901198
flat, _ = indexparse(nothing, ijk)
11911199
push!(list, flat)

0 commit comments

Comments
 (0)