Skip to content

Commit de7b51d

Browse files
committed
Support AbstractRange with check_args, and have broadcasting behave more like loops w/ respect to inlining and check_args (remove StridedArray restrictions).
1 parent 5f29cf7 commit de7b51d

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

src/broadcast.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ end
338338
# size of dest determines loops
339339
# function vmaterialize!(
340340
@generated function vmaterialize!(
341-
dest::StridedArray{T,N}, bc::BC, ::Val{Mod}
341+
dest::AbstractArray{T,N}, bc::BC, ::Val{Mod}
342342
) where {T <: NativeTypes, N, BC <: Union{Broadcasted,Product}, Mod}
343343
# we have an N dimensional loop.
344344
# need to construct the LoopSet
@@ -358,18 +358,23 @@ end
358358
add_simple_store!(ls, :dest, ArrayReference(:dest, loopsyms), elementbytes)
359359
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
360360
# return ls
361-
q = lower(ls)
361+
q = lower(ls, 0)
362362
push!(q.args, :dest)
363363
# @show q
364364
# q
365-
q = Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest, bc))))
366-
isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
365+
q = Expr(
366+
:block,
367+
ls.prepreamble,
368+
# Expr(:if, check_args_call(ls), Expr(:block, :(println("Primary code path!")), q), Expr(:block, :(println("Back up code path!")), :(Base.Broadcast.materialize!(dest, bc))))
369+
Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest, bc)))
370+
)
371+
# isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
367372
q
368373
# ls
369374
end
370375
@generated function vmaterialize!(
371376
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::BC, ::Val{Mod}
372-
) where {T <: NativeTypes, N, A <: StridedArray{T,N}, BC <: Union{Broadcasted,Product}, Mod}
377+
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, BC <: Union{Broadcasted,Product}, Mod}
373378
# we have an N dimensional loop.
374379
# need to construct the LoopSet
375380
loopsyms = [gensym(:n) for n 1:N]
@@ -387,25 +392,30 @@ end
387392
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
388393
add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)
389394
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
390-
q = lower(ls)
395+
q = lower(ls, 0)
391396
push!(q.args, :dest′)
392-
q = Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest′, bc))))
393-
isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
397+
q = Expr(
398+
:block,
399+
ls.prepreamble,
400+
Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest′, bc)))
401+
)
402+
# isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
394403
q
395404
# ls
396405
end
397-
function vmaterialize!(
398-
dest::StridedArray{T,N}, bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}}, ::Val{Mod}
406+
# these are marked `@inline` so the `@avx` itself can choose whether or not to inline.
407+
@inline function vmaterialize!(
408+
dest::AbstractArray{T,N}, bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}}, ::Val{Mod}
399409
) where {T <: NativeTypes, N, T2 <: Number, Mod}
400410
arg = T(first(bc.args))
401411
@avx for i eachindex(dest)
402412
dest[i] = arg
403413
end
404414
dest
405415
end
406-
function vmaterialize!(
416+
@inline function vmaterialize!(
407417
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}}, ::Val{Mod}
408-
) where {T <: NativeTypes, N, A <: StridedArray{T,N}, T2 <: Number, Mod}
418+
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, T2 <: NativeTypes, Mod}
409419
arg = T(first(bc.args))
410420
dest = parent(dest′)
411421
@avx for i eachindex(dest)

src/condense_loopset.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ end
263263
@inline check_args(::VectorizationBase.AbstractStridedPointer) = true
264264
@inline check_args(_) = false
265265
@inline check_args(A, B, C::Vararg{Any,K}) where {K} = check_args(A) && check_args(B, C...)
266+
@inline check_args(::AbstractRange{T}) where {T} = check_type(T)
266267
"""
267268
check_type(::Type{T}) where {T}
268269

0 commit comments

Comments
 (0)