Skip to content

Commit 440b838

Browse files
author
Michael Abbott
committed
disable avx for min/max grad
1 parent 1e86d2e commit 440b838

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

src/macro.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,7 @@ padmodclamp_pair(A, inds, store, assign=false) = begin
494494
elseif ex.args[1] == :pad && length(ex.args) >= 2
495495
i = ex.args[2]
496496
if !all(==(0), ex.args[3:end]) || length(ex.args) == 2
497-
# push!(nopadif, :($i ∈ $axes($A,$d)))
498-
push!(nopadif, :($i >= first(axes($A,$d))), :($i <= Base.last(axes($A,$d)))) # allows avx? Weirdly, deleting "Base." causes errors
497+
push!(nopadif, :($i >= first(axes($A,$d))), :($i <= last(axes($A,$d)))) # allows avx
499498
end
500499
return i
501500
end
@@ -1073,9 +1072,8 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex
10731072
safe = if act! == ACT!
10741073
isempty(store.unsafeleft)
10751074
else # working on ∇act!
1076-
isempty(store.unsaferight) # &&
1077-
# store.redfun == :+ && # Disable @avx for min/max grad, #53
1078-
# store.grad != :Dual # and for use with ForwardDiff
1075+
isempty(store.unsaferight)
1076+
store.redfun == :+ # Disable @avx for min/max grad, #53
10791077
end
10801078

10811079
if safe && store.avx != false && isdefined(store.mod, :LoopVectorization)

test/gradients.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This file is run several times
33
* with grad=Base vs grad=Dual
44
* with Tracker, Zygote
5-
* using KernelAbstractions, LoopVectorization, TensorCast
5+
* using KernelAbstractions, LoopVectorization, TensorOperations
66
=#
77

88
using Tullio, Test, ForwardDiff, Random

0 commit comments

Comments
 (0)