diff --git a/Project.toml b/Project.toml index 7946689..8e2c1ad 100644 --- a/Project.toml +++ b/Project.toml @@ -5,10 +5,12 @@ version = "0.0.1" [deps] DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] DiffRules = "1" +EllipsisNotation = "0.4" Requires = "1" julia = "1.3" diff --git a/src/Tullio.jl b/src/Tullio.jl index 365ff63..455bffa 100644 --- a/src/Tullio.jl +++ b/src/Tullio.jl @@ -1,5 +1,8 @@ module Tullio +using EllipsisNotation: (..) +using Base.Broadcast: newindex, newindexer, combine_axes # , check_broadcast_axes + #========== ⚜️ ==========# export @tullio diff --git a/src/macro.jl b/src/macro.jl index 4067411..7552ed3 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -68,6 +68,7 @@ function _tullio(exs...; mod=Main) # Everything writes into leftarray[leftraw...], sometimes with a generated name leftraw = [], leftind = Symbol[], # vcat(leftind, redind) is the complete list of loop indices + leftcast = nothing, leftarray = nothing, leftscalar = nothing, # only defined for scalar reduction leftnames = Symbol[], # for NamedDims @@ -75,8 +76,11 @@ function _tullio(exs...; mod=Main) right = nothing, rightind = Symbol[], sharedind = Symbol[], # indices appearing on every RHS array + rightcast = Expr[], # things like @view A[end,end,..] for broadcasting arrays = Symbol[], scalars = Symbol[], + broadargs = Symbol[], # results of newindexer, passed to actor as tuple BROAD + broadarrays = Symbol[], # arrays for which that must be unpacked for newindex cost = 1, # Index ranges: first save all known constraints constraints = Dict{Symbol,Vector}(), # :k => [:(axis(A,2)), :(axis(B,1))] etc. 
@@ -203,11 +207,12 @@ end # These only need not to clash with symbols in the input: RHS, AXIS = :𝓇𝒽𝓈, :𝒶𝓍 -ZED, TYP, ACC, KEEP = :ℛ, :𝒯, :𝒜𝒸𝒸, :♻ -EPS, DEL, EXPR = :𝜀, :𝛥, :ℰ𝓍 +ZED, TYP, ACC, KEEP = :ℛ, :𝒯, :𝒜𝒸𝒸, :♻ # used in act! function +EPS, DEL, EXPR = :𝜀, :𝛥, :ℰ𝓍 # used for derivatives +CART, IND, BOOL, FIRST, BROAD = :𝒞𝒶𝓇𝓉, :𝒾𝓃𝒹, :𝒷ℴℴ𝓁, :𝒻𝒾𝓇𝓈𝓉, :𝒷𝓇ℴ𝒶𝒹 # used for broadcasting ⚃ # These get defined globally, with a random number appended: -MAKE, ACT! = :𝒞𝓇ℯ𝒶𝓉ℯ, :𝒜𝒸𝓉! # :ℳ𝒶𝓀ℯ +MAKE, ACT! = :ℳ𝒶𝓀ℯ, :𝒜𝒸𝓉! #========== input parsing ==========# @@ -250,6 +255,7 @@ function parse_input(expr, store) unique!(store.leftind) # after last saveconstraints() unique!(store.sharedind) unique!(store.rightind) + unique!(store.broadargs) unique!(store.outpre) # kill mutiple assertions, and evaluate any f(A) only once store.redind = setdiff(store.rightind, store.leftind) @@ -281,7 +287,7 @@ rightwalk(store) = ex -> begin # Third, save letter A, and what axes(A) says about indices: push!(store.arrays, arrayonly(A)) inds = primeindices(inds) - saveconstraints(A, inds, store, true) + inds = saveconstraints(A, inds, store, true) # Re-assemble RHS with new A, and primes on indices taken care of. return :( $A[$(inds...)] ) @@ -297,7 +303,25 @@ saveconstraints(A, inds, store, right=true) = begin A1 = arrayfirst(A) is = Symbol[] foreach(enumerate(inds)) do (d,ex) - isconst(ex) && return + isconst(ex) && return # ?? now that saveconstraints() returns inds, you could do dollars here, is that better? You still need to catch any outside of indexing. + + if ex in (:(..), CART) # broadcasting! + d == length(inds) || error("can only use .. 
for broadcasting except after explicit indices") + ends = repeat([:end], d-1) + vex = :(@view $A1[$(ends...),$..]) + if right + push!(store.rightcast, vex) # will be used to compute range of CART + push!(store.broadarrays, A1) + inds[d] = Symbol(IND, A1) + else + store.leftcast = vex + end + boolA, firstA = Symbol(BOOL, A1), Symbol(FIRST, A1) + push!(store.axisdefs, :(local $boolA, $firstA = $newindexer($vex))) # done inside maker + push!(store.broadargs, boolA, firstA) # pass all from maker to actor as tuple BROAD + return # editing of RHS must happen later + end + range_i, i = range_expr_walk(length(inds)==1 ? :(eachindex($A1)) : :(axes($A1,$d)), ex) if i isa Symbol push!(is, i) @@ -309,6 +333,7 @@ saveconstraints(A, inds, store, right=true) = begin push!(store.shiftedind, i...) push!(store.pairconstraints, (i..., dollarstrip.(range_i)...)) end + end if right append!(store.rightind, is) @@ -323,13 +348,18 @@ saveconstraints(A, inds, store, right=true) = begin append!(store.leftind, is) # why can's this be the only path for store.leftind?? end n = length(inds) - if n==1 + if !isempty(store.rightcast) + str = "expected a $A1 to have at least $(n-1) dimensions" + push!(store.outpre, :( ndims($A1) >= $(n-1) || error($str) )) + elseif n==1 str = "expected a 1-array $A1, or a tuple" push!(store.outpre, :( $A1 isa Tuple || ndims($A1) == 1 || error($str) )) else str = "expected a $n-array $A1" # already arrayfirst(A) push!(store.outpre, :( ndims($A1) == $n || error($str) )) end + + inds end arrayfirst(A::Symbol) = A # this is for axes(A,d), axes(first(B),d), etc. @@ -352,7 +382,7 @@ dollarwalk(store) = ex -> begin @nospecialize ex ex isa Expr || return ex if ex.head == :call - ex.args[1] == :* && ex.args[2] === Int(0) && return false # tidy up dummy arrays! + ex.args[1] == :* && ex.args[2] === Int(0) && return false # tidy up dummy arrays! ?? 
these were needed before explicit ranges, can delete callcost(ex.args[1], store) # cost model for threading elseif ex.head == :$ # interpolation of $c things: ex.args[1] isa Symbol || error("you can only interpolate single symbols, not $ex") @@ -378,6 +408,8 @@ tidyleftraw(leftraw, store) = map(leftraw) do i end elseif i === :_ return 1 + elseif i === :(..) # broadcasting! + return CART # This symbol ends up in leftind, which is good. But for in-place, it's too early, will miss saveconstraints... unless that looks for CART too... end i end @@ -431,6 +463,7 @@ end function index_ranges(store) todo = Set(vcat(store.leftind, store.redind)) + pop!(todo, CART, nothing) for (i,j,r_i,r_j) in store.pairconstraints if haskey(store.constraints, i) # && i in todo ?? @@ -455,6 +488,34 @@ function index_ranges(store) end end + if isempty(store.rightcast) # no broadcasting, but make some trivial definitions + axs = Symbol(CART, :axes) + push!(store.axisdefs, quote + local $axs = () + local $BROAD = () + end) + else # broadcasting! 
+ axs = Symbol(CART, :axes) # this is the tuple of axes, also used for making a new array + carts = Symbol(AXIS, CART) # this is the CartesianIndex iterated over + push!(store.axisdefs, quote + local $axs = $combine_axes($(store.rightcast...)) + local $carts = $CartesianIndices($axs) + local $BROAD = tuple($(store.broadargs...)) + end) + if !(:newarray in store.flags) + push!(store.axisdefs, :($axs == $axes($(store.leftcast)) || + throw(DimensionMismatch(("LHS does not match broadcast dimensions from RHS"))))) + end + # Now also deal with RHS, where we must use CART + broadargs to calculate IND_A + rightex = map(store.broadarrays) do A + indA = Symbol(IND, A) + boolA, firstA = Symbol(BOOL, A), Symbol(FIRST, A) + :($indA = $newindex($CART, $boolA, $firstA)) + end + store.right = :($(rightex...); $(store.right)) + + end + append!(store.outex, store.axisdefs) end @@ -500,6 +561,7 @@ function output_array(store) # This now checks for OffsetArrays, and allows A[i,1] := ... outaxes = map(store.leftraw) do i i isa Integer && i==1 && return :(Base.OneTo(1)) + i == CART && return :($(Symbol(CART, :axes))...) i isa Symbol && return Symbol(AXIS, i) error("can't use index $i on LHS for a new array") end @@ -515,7 +577,7 @@ function output_array(store) simex = if isempty(store.arrays) # :( zeros($TYP, tuple($(outaxes...))) ) # Array{T} doesn't accept ranges... 
but zero() doesn't accept things like @tullio [i,j] := (i,j) i ∈ 2:3, j ∈ 4:5 - :( similar([], $TYP, tuple($(outaxes...))) ) + :( similar([], $TYP, tuple($(outaxes...)),) ) else :( similar($(store.arrays[1]), $TYP, tuple($(outaxes...),)) ) end @@ -561,10 +623,14 @@ function action_functions(store) push!(store.outeval, quote function $make($(store.arrays...), $(store.scalars...), ) $sofar - $threader($act!, $ST, $(store.leftarray), - tuple($(store.arrays...), $(store.scalars...),), - tuple($(axisleft...),), tuple($(axisred...),); - block=$block, keep=$keep) + # $threader($act!, $ST, $(store.leftarray), + # tuple($(store.arrays...), $(store.scalars...),), + # tuple($(axisleft...),), tuple($(axisred...),); # missing BROAD + # block=$block, keep=$keep) + $act!($ST, $(store.leftarray), + $(store.arrays...), $(store.scalars...), + $(axisleft...), $(axisred...), $BROAD, + $keep) return $(store.leftarray) end end) @@ -636,15 +702,37 @@ function action_functions(store) store.threads==true ? (BLOCK[] ÷ store.cost) : store.threads push!(store.outex, quote - $threader($act!, $ST, $(store.leftarray), - tuple($(store.arrays...), $(store.scalars...),), - tuple($(axisleft...),), tuple($(axisred...),); - block = $block, keep = $keep) + # $threader($act!, $ST, $(store.leftarray), + # tuple($(store.arrays...), $(store.scalars...),), + # tuple($(axisleft...),), tuple($(axisred...),); + # block = $block, keep = $keep) + $act!($ST, $(store.leftarray), + $(store.arrays...), $(store.scalars...), + $(axisleft...), $(axisred...), $BROAD, + $keep) $(store.leftarray) end) end end +#= +I don't like constructing threader($act.. twice + +I'd like to only @eval if doing gradients + +Order: +* without grads, act! should be defined before make. Or perhaps there is no make function then. +* with grads, both make and act! should be defined before @adjoint etc. 
+ + +First I add stuff to make a new matrix +If I need make(), I can slurp that up & make a function + +Once I add act!, I can no longer slurp. +Function act! should come first, but if it's not a function, then later... + +=# + """ make_many_actors(f!, args, ex1, [:i,], ex3, [:k,], ex5, ex6, store) @@ -653,7 +741,7 @@ This makes several functions of this form, decorated as necessary with `@inbouds` or `@avx` etc, and with appropriate `storage_type` as the first argument. ``` -f!(::Type, args..., keep=nothing) where {T} +f!(::Type, args..., broad=(), keep=nothing) where {T} ex1 ex2 = (for i in axis_i ex3 @@ -667,12 +755,17 @@ end """ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex5, ex6, store) + if !isempty(store.broadargs) + bex = :(($(store.broadargs...),) = $BROAD) # broadcasting! unpack the extra arguments + ex1 = :($bex; $ex1) + end + ex4 = recurseloops(ex5, inner) ex2 = recurseloops(:($ex3; $ex4; $ex6), outer) push!(store.outeval, quote - function $act!(::Type, $(args...), $KEEP=nothing) where {$TYP} + function $act!(::Type, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP} @debug "base actor:" typeof.(tuple($(args...))) @inbounds @fastmath ($ex1; $ex2) end @@ -690,7 +783,7 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex unroll = store.avx == true ? 
0 : store.avx # unroll=0 is the default setting push!(store.outeval, quote - function $act!(::Type{<:Array{<:Union{Base.HWReal, Bool}}}, $(args...), $KEEP=nothing) where {$TYP} + function $act!(::Type{<:Array{<:Union{Base.HWReal, Bool}}}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP} @debug "LoopVectorization @avx actor, unroll=$unroll" $expre LoopVectorization.@avx unroll=$unroll $exloop @@ -711,12 +804,12 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex sizes = map(ax -> :(length($ax)), axouter) push!(store.outeval, quote - KernelAbstractions.@kernel function $kernel($(args...), $KEEP) where {$TYP} + KernelAbstractions.@kernel function $kernel($(args...), $BROAD, $KEEP) where {$TYP} ($(outer...),) = @index(Global, NTuple) ($ex1; $ex3; $ex4; $ex6) end - function $act!(::Type{<:CuArray}, $(args...), $KEEP=nothing) where {$TYP} + function $act!(::Type{<:CuArray}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP} @debug "KernelAbstractions CuArrays actor" cu_kern! = $kernel(CUDA(), $(store.cuda)) $(asserts...) @@ -725,7 +818,7 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex end # Just for testing really... - function $act!(::Type{<:Array}, $(args...), $KEEP=nothing) where {$TYP} + function $act!(::Type{<:Array}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP} @debug "KernelAbstractions CPU actor:" typeof.(tuple($(args...))) cpu_kern! = $kernel(CPU(), Threads.nthreads()) $(asserts...) diff --git a/src/threads.jl b/src/threads.jl index 3b85a9e..d2690a2 100644 --- a/src/threads.jl +++ b/src/threads.jl @@ -38,6 +38,8 @@ Then it divides up the other axes, each accumulating in its own copy of `Z`. `keep=nothing` means that it overwrites the array, anything else (`keep=true`) adds on. 
""" function threader(fun!::Function, T::Type, Z::AbstractArray, As::Tuple, I0s::Tuple, J0s::Tuple; block, keep=nothing) + return fun!(T, Z, As..., I0s..., J0s..., keep) + # not yet fixed up for broadcasting Is = map(UnitRange, I0s) Js = map(UnitRange, J0s) if isnothing(block) diff --git a/test/parsing.jl b/test/parsing.jl index c651a68..65047aa 100644 --- a/test/parsing.jl +++ b/test/parsing.jl @@ -161,6 +161,34 @@ end end +@testset "broadcasting" begin + + f1(A) = @tullio C[i, ..] := A[i, ..] + 1 + @test f1(ones(3)) == ones(3) .+ 1 + @test f1(ones(3,4)) == ones(3,4) .+ 1 + @test f1(ones(3,4,5)) == ones(3,4,5) .+ 1 + + f2(A) = @tullio C[i, ..] := A[i, k, ..] + @test f2(ones(3,4)) == fill(4.0, 3) + A3 = rand(3,4,5) + @test f2(A3) ≈ dropdims(sum(A3, dims=2), dims=2) + + f3(A, B) = @tullio C[i,j, ..] := A[i, k, ..] * B[j, k, ..] + A2 = rand(3,3); + B2 = rand(3,3); + @test f3(A2, B2) ≈ A2 * B2' + A3 = rand(3,3,2); + B3 = rand(3,3,2); + C3 = f3(A3, B3) + @test C3[:,:,1] ≈ A3[:,:,1] * B3[:,:,1]' + @test C3[:,:,2] ≈ A3[:,:,2] * B3[:,:,2]' + + C4 = f3(A3, B2) + @test C4[:,:,1] ≈ A3[:,:,1] * B2[:,:]' + @test C4[:,:,2] ≈ A3[:,:,2] * B2[:,:]' + +end + @testset "without packages" begin A = [i^2 for i in 1:10]