Skip to content

Commit 80b28d2

Browse files
committed
Use combinemasks function; dispatches to & given shared types, and ? when mixing bools and others.
1 parent 9f5ea41 commit 80b28d2

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/lower_store.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ function lower_conditionalstore_scalar!(
7070
end
7171
nothing
7272
end
73+
@inline combinemasks(a::Unsigned, b::Unsigned) = a & b
74+
@inline combinemasks(a::Unsigned, b::Bool) = b ? a : zero(a)
75+
@inline combinemasks(a::Bool, b::Unsigned) = a ? b : zero(b)
76+
@inline combinemasks(a::Bool, b::Bool) = a & b
7377
function lower_conditionalstore_vectorized!(
7478
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
7579
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned}, isunrolled::Bool
@@ -101,7 +105,7 @@ function lower_conditionalstore_vectorized!(
101105
condvarname = varassignname(condvar, u, condunrolled)
102106
instrcall = Expr(:call, lv(:vstore!), ptr, name, mo)
103107
if mask !== nothing && (vecnotunrolled || u == U - 1)
104-
push!(instrcall.args, Expr(:call, :&, condvarname, mask))
108+
push!(instrcall.args, Expr(:call, lv(:combinemasks), condvarname, mask))
105109
else
106110
push!(instrcall.args, condvarname)
107111
end

src/operation_evaluation_order.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T, ld::Vector{Symbol}, id::Int) where {T}
23
adal[identifier(op)] == val && return # must already have been set
34
# ld != loopdependencies(op) &&

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,6 @@ end
13401340
end
13411341
qq[:,Base.OneTo(maxk)] ./= vec(lse)
13421342
end
1343-
13441343
add_1_dim(x::AbstractArray) = reshape(x, size(x)..., 1)
13451344
check_finite(x::AbstractArray) = all(isfinite.(x)) || throw(error("x not finite!"))
13461345
function softmax3_setup!(q::AA, lse::A, tmpmax::A, x::AA, maxk=size(q, ndims(q)) ) where {T<:Real, A<:Array{T}, AA<:AbstractArray{T}}
@@ -1492,7 +1491,7 @@ end
14921491
@test y1 y2
14931492

14941493

1495-
ni, nj, nk = (100, 100, 10)
1494+
ni, nj, nk = (127, 113, 13)
14961495
x = rand(T, ni, nj, nk);
14971496
q1 = similar(x);
14981497
q2 = similar(x);

0 commit comments

Comments
 (0)