From db3deab7afc04011beef62d0d3e9607b8e8648fd Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 20 Mar 2020 18:22:49 +0100 Subject: [PATCH 1/4] faster capture macro for one pattern --- src/TensorCast.jl | 1 + src/capture.jl | 88 +++++++++++++++++++++++++++++++++++++++++++++++ src/macro.jl | 24 ++++++------- 3 files changed, 101 insertions(+), 12 deletions(-) create mode 100644 src/capture.jl diff --git a/src/TensorCast.jl b/src/TensorCast.jl index ded64a6..e30be5b 100644 --- a/src/TensorCast.jl +++ b/src/TensorCast.jl @@ -6,6 +6,7 @@ export @cast, @reduce, @matmul, @pretty using MacroTools, StaticArrays, Compat using LinearAlgebra, Random +include("capture.jl") include("macro.jl") include("pretty.jl") include("string.jl") diff --git a/src/capture.jl b/src/capture.jl new file mode 100644 index 0000000..e506a34 --- /dev/null +++ b/src/capture.jl @@ -0,0 +1,88 @@ +""" + @capture_(ex, A_[ijk__]) + +A faster drop-in replacement for `MacroTools.@capture`, for this particular pattern only. +""" +macro capture_(ex, pat::Expr) + pat.head == :ref && + length(pat.args)==2 && + endswith(string(pat.args[1]), '_') && + endswith(string(pat.args[2]), "__") || error("@capture_ only works on pattern A_[ijk__]") + + A = Symbol(string(pat.args[1])[1:end-1]) + ijk = Symbol(string(pat.args[2])[1:end-2]) + @gensym res + quote + $A, $ijk = nothing, nothing + $res = TensorCast._trymatch($ex) + if $res == nothing + false + else + $A, $ijk = $res + true + end + end |> esc +end + +_trymatch(s) = nothing +function _trymatch(ex::Expr) + ex.head == :ref || return nothing + ex.args[1], ex.args[2:end] +end + +#= + +julia> ex = :(Z[1,2,3]) + +julia> @pretty @capture(ex, A_[ijk__]) +begin + A = MacroTools.nothing + ijk = MacroTools.nothing + tarsier = trymatch($(Expr(:copyast, :($(QuoteNode(:(A_[ijk__])))))), ex) + if tarsier == MacroTools.nothing + false + else + A = get(tarsier, :A, MacroTools.nothing) + ijk = get(tarsier, :ijk, MacroTools.nothing) + true + end +end + +julia> @pretty @capture_(ex, A_[ijk__]) +begin + A = nothing + ijk = nothing + louse = _trymatch(ex) + if louse == nothing + false + else + (A, ijk) = louse + true + end +end + + + +ex = :( A[i,j][k] + B[I[i],J[j],k]^2 / 2 ) +f1(x) = MacroTools.postwalk(ex) do x + @capture(x, A_[ijk__]) || return x + :($A[$(ijk...),9]) + end +f2(x) = MacroTools.postwalk(ex) do x + @capture_(x, A_[ijk__]) || return x + :($A[$(ijk...),9]) + end +f1(ex) +f2(ex) + +@btime f1(x) setup=(x=ex) # 3.181 ms +@btime f2(x) setup=(x=ex) # 31.440 μs -- 100x faster. + + +$ time julia -e 'using TensorCast; TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' + +real 0m8.567s # was 0m9.485s +user 0m8.295s +sys 0m0.329s + +=# diff --git a/src/macro.jl b/src/macro.jl index eb7ed2f..9aea1fd 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -238,7 +238,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) if @capture(ex, A_{ijk__}) static=true push!(call.flags, :staticslice) - elseif @capture(ex, A_[ijk__]) + elseif @capture_(ex, A_[ijk__]) static=false else return ex @@ -380,7 +380,7 @@ Simple glue / stand. does not permutedims, but broadcasting may have to... avoid function standardglue(ex, target, store::NamedTuple, call::CallInfo) # The sole target here is indexing expressions: - if @capture(ex, A_[inner__]) + if @capture_(ex, A_[inner__]) static=false elseif @capture(ex, A_{inner__}) static=true @@ -471,7 +471,7 @@ by permutedims and if necessary broadcasting, always using `readycast()`. function targetcast(ex, target, store::NamedTuple, call::CallInfo) # If just one naked expression, then we won't broadcast: - if @capture(ex, A_[ijk__]) + if @capture_(ex, A_[ijk__]) containsindexing(A) && error("that should have been dealt with") return readycast(ex, target, store, call) end @@ -536,7 +536,7 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) return :( Core._apply($funs[$(ijk...)], $(args...) ) ) # Apart from those, readycast acts only on lone tensors: - @capture(ex, A_[ijk__]) || return ex + @capture_(ex, A_[ijk__]) || return ex dims = Int[ findcheck(i, target, call, " on the left") for i in ijk ] @@ -656,7 +656,7 @@ function recursemacro(ex, store::NamedTuple, call::CallInfo) end # Tidy up indices, A[i,j][k] will be hit on different rounds... - if @capture(ex, A_[ijk__]) + if @capture_(ex, A_[ijk__]) return :( $A[$(tensorprimetidy(ijk)...)] ) elseif @capture(ex, A_{ijk__}) return :( $A{$(tensorprimetidy(ijk)...)} ) @@ -875,7 +875,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) push!(outsize, szwrap(ii)) save && saveonesize(ii, :(size($A, $d)), store) - elseif @capture(i, B_[klm__]) + elseif @capture_(i, B_[klm__]) innerparse(B, klm, store, call) # called just for error on tensor/colon/constant sub = indexparse(B, klm, store, call; save=save) # I do want to save size(B,1) etc. append!(flat, sub.flat) @@ -1143,7 +1143,7 @@ isconstant(ex::Expr) = ex.head == :($) isconstant(q::QuoteNode) = false isindexing(s) = false -isindexing(ex::Expr) = @capture(x, A_[ijk__]) +isindexing(ex::Expr) = @capture_(x, A_[ijk__]) isCorI(i) = isconstant(i) || isindexing(ii) @@ -1171,10 +1171,10 @@ iscolon(q::QuoteNode) = true containsindexing(s) = false function containsindexing(ex::Expr) flag = false - # MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex) + # MacroTools.postwalk(x -> @capture_(x, A_[ijk__]) && (flag=true), ex) MacroTools.postwalk(ex) do x - # @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true) - if @capture(x, A_[ijk__]) + # @capture_(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true) + if @capture_(x, A_[ijk__]) # @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j]) flag=true end @@ -1186,7 +1186,7 @@ listindices(s::Symbol) = [] function listindices(ex::Expr) list = [] MacroTools.postwalk(ex) do x - if @capture(x, A_[ijk__]) + if @capture_(x, A_[ijk__]) flat, _ = indexparse(nothing, ijk) push!(list, flat) end @@ -1445,7 +1445,7 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) push!(store.mustassert, :( TensorCast.@assert_ ndims($zed)==0 $str) ) else newleft = standardise(parsed.left, store, call) - @capture(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft")) + @capture_(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft")) if !(zed isa Symbol) # then standardise did something! push!(call.flags, :showfinal) From ba8523973e344f509913b9dd67f7680b9d4c048c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 29 Mar 2020 21:03:06 +0200 Subject: [PATCH 2/4] ... for two patterns --- src/capture.jl | 70 +++++++++++++++++++++++++++++++++++++++++++------- src/macro.jl | 20 +++++++++------ 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/capture.jl b/src/capture.jl index e506a34..f2be1e1 100644 --- a/src/capture.jl +++ b/src/capture.jl @@ -4,17 +4,23 @@ A faster drop-in replacement for `MacroTools.@capture`, for this particular pattern only. """ macro capture_(ex, pat::Expr) - pat.head == :ref && + + pat.head in [:ref, :curly] && length(pat.args)==2 && endswith(string(pat.args[1]), '_') && - endswith(string(pat.args[2]), "__") || error("@capture_ only works on pattern A_[ijk__]") + endswith(string(pat.args[2]), "__") || + error("@capture_ doesn't work on pattern $pat") A = Symbol(string(pat.args[1])[1:end-1]) ijk = Symbol(string(pat.args[2])[1:end-2]) + + qn = QuoteNode(pat.head) + @gensym res quote $A, $ijk = nothing, nothing - $res = TensorCast._trymatch($ex) + # $res = TensorCast._trymatch($ex, Val($qn)) + $res = _trymatch($ex, Val($qn)) if $res == nothing false else @@ -24,12 +30,54 @@ macro capture_(ex, pat::Expr) end |> esc end -_trymatch(s) = nothing -function _trymatch(ex::Expr) - ex.head == :ref || return nothing - ex.args[1], ex.args[2:end] -end +_trymatch(s, v) = nothing # s::Symbol +_trymatch(ex::Expr, ::Val{:ref}) = # A_[ijk__] + if ex.head === :ref + ex.args[1], ex.args[2:end] + else + nothing + end +_trymatch(ex::Expr, ::Val{:curly}) = # A_{ijk__} + if ex.head === :curly + ex.args[1], ex.args[2:end] + else + nothing + end + + # elseif pat.head == :call && pat.args[1] === :| + # length(pat.args)==3 || error("@capture_ doesn't work on pattern $pat") # Or syntax + # for i in 2:3 + # pat.args[i].head in [:ref, :curly] && + # length(pat.args[i].args)==2 && + # endswith(string(pat.args[i].args[1]), '_') && + # endswith(string(pat.args[i].args[2]), "__") || + # error("@capture_ doesn't work on pattern $pat") + + # A = Symbol(string(pat.args[i].args[1])[1:end-1]) + # ijk = Symbol(string(pat.args[i].args[1])[1:end-2]) + # end + # end +# _trymatch(ex::Expr, ::Val{:call}) = # A | B +# if ex.head === :call && ex.args[1] === :| +# ex.args[1], ex.args[2:end] +# else +# nothing +# end + + +# elseif pat.head === :call && pat.args[1] === :| # Or syntax +# left = _trymatch(pat.args[2], ex) +# if left !== nothing +# return left +# else +# return right = _trymatch(pat.args[3], ex) +# end +# end +# nothing +# end + +# || pat.head === :curly # Patten is A_[ijk__] or A_{ijk__} #= julia> ex = :(Z[1,2,3]) @@ -81,8 +129,12 @@ f2(ex) $ time julia -e 'using TensorCast; TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' -real 0m8.567s # was 0m9.485s +real 0m8.828s # was 0m9.485s user 0m8.295s sys 0m0.329s +$ time julia -e 'using TensorCast; @time TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' + +5.194806 seconds + =# diff --git a/src/macro.jl b/src/macro.jl index 9aea1fd..a087e1b 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -235,7 +235,7 @@ but also pushes `A = f(x)` into `store.top`. """ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) # This acts only on single indexing expressions: - if @capture(ex, A_{ijk__}) + if @capture_(ex, A_{ijk__}) static=true push!(call.flags, :staticslice) elseif @capture_(ex, A_[ijk__]) @@ -285,11 +285,11 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) end # Nested indices A[i,j,B[k,l,m],n] or worse A[i,B[j,k],C[i,j]] - if any(i -> @capture(i, B_[klm__]), ijk) + if any(i -> @capture_(i, B_[klm__]), ijk) newijk, beecolon = [], [] # for simple case # listB, listijk = [], [] for i in ijk - if @capture(i, B_[klm__]) + if @capture_(i, B_[klm__]) append!(newijk, klm) push!(beecolon, B) # push!(listijk, klm) @@ -382,7 +382,7 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) # The sole target here is indexing expressions: if @capture_(ex, A_[inner__]) static=false - elseif @capture(ex, A_{inner__}) + elseif @capture_(ex, A_{inner__}) static=true else return ex @@ -394,7 +394,7 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) end # Otherwise there are two options, (brodcasting...)[k] or simple B[i,j][k] - needcast = !@capture(A, B_[outer__]) + needcast = !@capture_(A, B_[outer__]) if needcast outer = unique(reduce(vcat, listindices(A))) @@ -503,6 +503,7 @@ end This is walked over the expression to prepare for `@__dot__` etc, by `targetcast()`. """ function readycast(ex, target, store::NamedTuple, call::CallInfo) + ex isa Symbol && return ex # quit early? # Scalar functions can be protected entirely from broadcasting: # TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function! @@ -631,6 +632,7 @@ pushing calculation steps into store. Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`. """ function recursemacro(ex, store::NamedTuple, call::CallInfo) + ex isa Symbol && return ex # quit early? # Actually look for recursion if @capture(ex, @reduce(subex__) ) @@ -658,7 +660,7 @@ function recursemacro(ex, store::NamedTuple, call::CallInfo) # Tidy up indices, A[i,j][k] will be hit on different rounds... if @capture_(ex, A_[ijk__]) return :( $A[$(tensorprimetidy(ijk)...)] ) - elseif @capture(ex, A_{ijk__}) + elseif @capture_(ex, A_{ijk__}) return :( $A{$(tensorprimetidy(ijk)...)} ) else return ex @@ -680,7 +682,8 @@ function rightsizes(ex, store::NamedTuple, call::CallInfo) if @capture(ex, A_[outer__][inner__] | A_[outer__]{inner__} ) field = nothing elseif @capture(ex, A_[outer__].field_[inner__] | A_[outer__].field_{inner__} ) - elseif @capture(ex, A_[outer__] | A_{outer__} ) + # elseif @capture(ex, A_[outer__] | A_{outer__} ) + elseif @capture_(ex, A_[outer__] ) || @capture_(ex, A_{outer__} ) field = nothing else return ex @@ -753,7 +756,7 @@ function castparse(ex, store::NamedTuple, call::CallInfo; reduce=false) error("wtf is $ex") end - static = @capture(left, ZZ_{ii__}) + static = @capture_(left, ZZ_{ii__}) if @capture(left, Z_[outer__][inner__] | [outer__][inner__] | Z_[outer__]{inner__} | [outer__]{inner__} ) isnothing(Z) && (:inplace in call.flags) && throw(MacroError("can't write into a nameless tensor", call)) @@ -1116,6 +1119,7 @@ end tensorprimetidy(v::Vector) = Any[ tensorprimetidy(x) for x in v ] function tensorprimetidy(ex) MacroTools.postwalk(ex) do x + x isa Symbol && return x # quit early? @capture(x, ((ij__,) \ k_) ) && return :( ($(ij...),$k) ) @capture(x, i_ \ j_ ) && return :( ($i,$j) ) From b3eee4050db4299edd8c76a886fabf84b569457b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 6 Apr 2020 16:04:55 +0200 Subject: [PATCH 3/4] add more cases to macro --- src/capture.jl | 128 +++++++++++++++++++++++++++-------------------- src/macro.jl | 22 ++++---- test/runtests.jl | 1 + test/two.jl | 33 ++++++++++++ 4 files changed, 118 insertions(+), 66 deletions(-) diff --git a/src/capture.jl b/src/capture.jl index f2be1e1..1f3bd96 100644 --- a/src/capture.jl +++ b/src/capture.jl @@ -1,83 +1,98 @@ """ @capture_(ex, A_[ijk__]) -A faster drop-in replacement for `MacroTools.@capture`, for this particular pattern only. +Faster drop-in replacement for `MacroTools.@capture`, for a few patterns only: +* `A_[ijk__]` and `A_{ijk__}` +* `[ijk__]` +* `A_.field_` +* `A_ := B_` and `A_ = B_` and `A_ += B_` etc. +* `f_(x_)` """ macro capture_(ex, pat::Expr) - pat.head in [:ref, :curly] && - length(pat.args)==2 && - endswith(string(pat.args[1]), '_') && - endswith(string(pat.args[2]), "__") || - error("@capture_ doesn't work on pattern $pat") + H = QuoteNode(pat.head) + + A,B = if pat.head in [:ref, :curly] && length(pat.args)==2 && + _endswithone(pat.args[1]) && _endswithtwo(pat.args[2]) # :( A_[ijk__] ) + _symbolone(pat.args[1]), _symboltwo(pat.args[2]) + + elseif pat.head == :. && + _endswithone(pat.args[1]) && _endswithone(pat.args[2].value) # :( A_.field_ ) + _symbolone(pat.args[1]), _symbolone(pat.args[2].value) - A = Symbol(string(pat.args[1])[1:end-1]) - ijk = Symbol(string(pat.args[2])[1:end-2]) + elseif pat.head == :call && length(pat.args)==2 && + _endswithone(pat.args[1]) && _endswithone(pat.args[2]) # :( f_(x_) ) + _symbolone(pat.args[1]), _symbolone(pat.args[2]) - qn = QuoteNode(pat.head) + elseif pat.head in [:call, :(=), :(:=), :+=, :-=, :*=, :/=] && + _endswithone(pat.args[1]) && _endswithone(pat.args[2]) # :( A_ += B_ ) + _symbolone(pat.args[1]), _symbolone(pat.args[2]) + + elseif pat.head == :vect && _endswithtwo(pat.args[1]) # :( [ijk__] ) + _symboltwo(pat.args[1]), gensym(:ignore) + + else + error("@capture_ doesn't work on pattern $pat") + end @gensym res quote - $A, $ijk = nothing, nothing - # $res = TensorCast._trymatch($ex, Val($qn)) - $res = _trymatch($ex, Val($qn)) - if $res == nothing + $A, $B = nothing, nothing + $res = TensorCast._trymatch($ex, Val($H)) + # $res = _trymatch($ex, Val($H)) + if $res === nothing false else - $A, $ijk = $res + $A, $B = $res true end end |> esc end -_trymatch(s, v) = nothing # s::Symbol -_trymatch(ex::Expr, ::Val{:ref}) = # A_[ijk__] - if ex.head === :ref +_endswithone(ex) = endswith(string(ex), '_') && !_endswithtwo(ex) +_endswithtwo(ex) = endswith(string(ex), "__") + +_symbolone(ex) = Symbol(string(ex)[1:end-1]) +_symboltwo(ex) = Symbol(string(ex)[1:end-2]) + +_getvalue(::Val{val}) where {val} = val + +_trymatch(s, v) = nothing # Symbol, or other Expr +_trymatch(ex::Expr, pat::Union{Val{:ref}, Val{:curly}}) = # A_[ijk__] or A_{ijk__} + if ex.head === _getvalue(pat) ex.args[1], ex.args[2:end] else nothing end -_trymatch(ex::Expr, ::Val{:curly}) = # A_{ijk__} - if ex.head === :curly - ex.args[1], ex.args[2:end] +_trymatch(ex::Expr, ::Val{:.}) = # A_.field_ + if ex.head === :. + ex.args[1], ex.args[2].value + else + nothing + end +_trymatch(ex::Expr, pat::Val{:call}) = + if ex.head === _getvalue(pat) && length(ex.args) == 2 + ex.args[1], ex.args[2] + else + nothing + end +_trymatch(ex::Expr, pat::Union{Val{:(=)}, Val{:(:=)}, Val{:(+=)}, Val{:(-=)}, Val{:(*=)}, Val{:(/=)}}) = + if ex.head === _getvalue(pat) + ex.args[1], ex.args[2] + else + nothing + end +_trymatch(ex::Expr, ::Val{:vect}) = # [ijk__] + if ex.head === :vect + ex.args, nothing else nothing end +# Cases for Tullio: +# @capture(ex, B_[inds__].field_) --> @capture_(ex, Binds_.field_) && @capture_(Binds, B_[inds__]) +# @capture(ex, [inds__]) - # elseif pat.head == :call && pat.args[1] === :| - # length(pat.args)==3 || error("@capture_ doesn't work on pattern $pat") # Or syntax - # for i in 2:3 - # pat.args[i].head in [:ref, :curly] && - # length(pat.args[i].args)==2 && - # endswith(string(pat.args[i].args[1]), '_') && - # endswith(string(pat.args[i].args[2]), "__") || - # error("@capture_ doesn't work on pattern $pat") - - # A = Symbol(string(pat.args[i].args[1])[1:end-1]) - # ijk = Symbol(string(pat.args[i].args[1])[1:end-2]) - # end - # end -# _trymatch(ex::Expr, ::Val{:call}) = # A | B -# if ex.head === :call && ex.args[1] === :| -# ex.args[1], ex.args[2:end] -# else -# nothing -# end - - -# elseif pat.head === :call && pat.args[1] === :| # Or syntax -# left = _trymatch(pat.args[2], ex) -# if left !== nothing -# return left -# else -# return right = _trymatch(pat.args[3], ex) -# end -# end -# nothing -# end - -# || pat.head === :curly # Patten is A_[ijk__] or A_{ijk__} #= julia> ex = :(Z[1,2,3]) @@ -129,12 +144,15 @@ f2(ex) $ time julia -e 'using TensorCast; TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' -real 0m8.828s # was 0m9.485s +real 0m8.132s # was 0m8.900s on master, noise or signal? +user 0m7.747s +sys 0m0.358s +real 0m8.132s user 0m8.295s sys 0m0.329s $ time julia -e 'using TensorCast; @time TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' -5.194806 seconds +4.899634 seconds, best run # was 5.845 on master, that's a second? =# diff --git a/src/macro.jl b/src/macro.jl index a087e1b..3277efd 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -245,7 +245,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) end # Ensure that f(x)[i,j] will evaluate once, including in size(A) - if A isa Symbol || @capture(A, AA_.ff_) # caller has ensured !containsindexing(A) + if A isa Symbol || @capture_(A, AA_.ff_) # caller has ensured !containsindexing(A) else Asym = Symbol(A,"_val") # exact same symbol is used by rightsizes() push!(store.top, :( local $Asym = $A ) ) @@ -521,10 +521,10 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) return :( getproperty($fun($(arg...)), $(QuoteNode(field))) ) # tuple creation... now including namedtuples @capture(ex, (args__,) ) && any(containsindexing, args) && - if any(a -> @capture(a, sym_ = val_), args) + if any(a -> @capture_(a, sym_ = val_), args) syms, vals = [], [] map(args) do a - @capture(a, sym_ = val_ ) || throw(MacroError("invalid named tuple element $a", call)) + @capture_(a, sym_ = val_ ) || throw(MacroError("invalid named tuple element $a", call)) push!(syms, QuoteNode(sym)) push!(vals, val) end @@ -690,13 +690,13 @@ function rightsizes(ex, store::NamedTuple, call::CallInfo) end # Special treatment for fun(x)[i,j], goldilocks A not just symbol, but no indexing - if A isa Symbol || @capture(A, AA_.ff_) + if A isa Symbol || @capture_(A, AA_.ff_) elseif !containsindexing(A) A = Symbol(A,"_val") # the exact same symbol is used by standardiser end # When we can save the sizes, then we destroy so as not to save again: - if A isa Symbol || @capture(A, AA_.ff_) && !containsindexing(A) + if A isa Symbol || @capture_(A, AA_.ff_) && !containsindexing(A) indexparse(A, outer, store, call; save=true) if field==nothing innerparse(:(first($A)), inner, store, call; save=true) @@ -722,24 +722,24 @@ function castparse(ex, store::NamedTuple, call::CallInfo; reduce=false) Z = gensym(:left) # Do we make a new array? With or without collecting: - if @capture(ex, left_ := right_ ) + if @capture_(ex, left_ := right_ ) elseif @capture(ex, left_ == right_ ) @warn "using == no longer does anything" call.string maxlog=1 _id=hash(call.string) elseif @capture(ex, left_ |= right_ ) push!(call.flags, :collect) # Do we write into an exising array? Possibly updating it: - elseif @capture(ex, left_ = right_ ) + elseif @capture_(ex, left_ = right_ ) push!(call.flags, :inplace) - elseif @capture(ex, left_ += right_ ) + elseif @capture_(ex, left_ += right_ ) push!(call.flags, :inplace) right = :( $left + $right ) reduce && throw(MacroError("can't use += with @reduce", call)) - elseif @capture(ex, left_ -= right_ ) + elseif @capture_(ex, left_ -= right_ ) push!(call.flags, :inplace) right = :( $left - ($right) ) reduce && throw(MacroError("can't use -= with @reduce", call)) - elseif @capture(ex, left_ *= right_ ) + elseif @capture_(ex, left_ *= right_ ) push!(call.flags, :inplace) right = :( $left * ($right) ) reduce && throw(MacroError("can't use *= with @reduce", call)) @@ -1443,7 +1443,7 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) pop!(call.flags, :nolazy, :ok) # ensure we use diagview(), Reverse{}, etc, not a copy if @capture(parsed.left, zed_[]) # special case Z[] = ... else allconst pulls it out - zed isa Symbol || @capture(zed, ZZ_.field_) || error("wtf") + zed isa Symbol || @capture_(zed, ZZ_.field_) || error("wtf") newleft = parsed.left str = "expected a 0-tensor $zed[]" push!(store.mustassert, :( TensorCast.@assert_ ndims($zed)==0 $str) ) diff --git a/test/runtests.jl b/test/runtests.jl index 22bf376..04b7df1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using Compat if VERSION >= v"1.1" using LoopVectorization end +using TensorCast: @capture_ @testset "ex-@shape" begin include("shape.jl") end @testset "@reduce" begin include("reduce.jl") end diff --git a/test/two.jl b/test/two.jl index 0992c32..b6bc6e3 100644 --- a/test/two.jl +++ b/test/two.jl @@ -356,3 +356,36 @@ end # @test_throws DimensionMismatch @cast M5[i,j] := fun(M[:,j]).same[i] i:99, j:4 # TODO make this check canonical length? end +@testset "capture_ macro" begin + + using TensorCast: @capture_ + + EXS = [:(A[i,j,k]), :(B{i,2,:}), :(C.dee), :(fun(5)), :(g := h+i), :(k[3] += l[4]), :([m,n,0]) ] + PATS = [:(A_[ijk__]), :(B_{ind__}), :(C_.d_), :(f_(arg_)), :(left_ := right_), :(a_ += b_), :([emm__]) ] + # @test length(EXS) == length(PATS) + @testset "ex = $(EXS[i])" for i in eachindex(EXS) + for j in eachindex(PATS) + @eval res = @capture_($EXS[$i], $(PATS[j])) + if i != j + @test res == false + else + @test res == true + if i==1 + @test A == :A + @test ijk == [:i, :j, :k] + elseif i==3 + @test C == :C + @test d == :dee + elseif i==5 + @test left == :g + @test right == :(h+i) + elseif i==7 + @test emm == [:m, :n, 0] + end + end + end + end + + @test !@capture_( :(f(1,2,3)), f_(x_) ) + +end From 4d342cf8dbc61e796cf35939003b4803b7310726 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 8 Apr 2020 08:25:18 +0200 Subject: [PATCH 4/4] lower optimisation level trick from Plots --- src/TensorCast.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/TensorCast.jl b/src/TensorCast.jl index e30be5b..7c586d3 100644 --- a/src/TensorCast.jl +++ b/src/TensorCast.jl @@ -1,6 +1,12 @@ module TensorCast +# This speeds up loading a bit... but might slow down functions which act on data: +# https://github.com/JuliaPlots/Plots.jl/pull/2544/files +if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@optlevel")) + @eval Base.Experimental.@optlevel 1 +end + export @cast, @reduce, @matmul, @pretty using MacroTools, StaticArrays, Compat