Skip to content

Commit b0cc779

Browse files
committed
Merge branch 'kf/fastinference'
2 parents c3b1827 + 3326b09 commit b0cc779

File tree

1 file changed

+44
-26
lines changed

1 file changed

+44
-26
lines changed

src/modeling/determinestrategy.jl

Lines changed: 44 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -690,35 +690,53 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
690690
10.0sum(maximum, values(stridepenaltydict)) * Base.power_by_squaring(0.0009765625, length(order))
691691
end
692692
end
693-
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
694-
@unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
695-
(vloopsym == u₁loopsym || vloopsym == u₂loopsym) && return 0, 0x00
696-
(isu₁unrolled(op) && isu₂unrolled(op)) || return 0, 0x00
697-
u₁step = step(getloop(ls, u₁loopsym))
698-
u₂step = step(getloop(ls, u₂loopsym))
699-
(isknown(u₁step) & isknown(u₂step)) || return 0, 0x00
700-
abs(gethint(u₁step)) == abs(gethint(u₂step)) || return 0, 0x00
701693

702-
istranslation = 0
703-
inds = getindices(op); li = op.ref.loopedindex
704-
for i ∈ eachindex(li)
705-
if !li[i]
706-
opp = findparent(ls, inds[i + (first(inds) === DISCONTIGUOUS)])
707-
if isu₁unrolled(opp) && isu₂unrolled(opp)
708-
if Base.sym_in(instruction(opp).instr, (:vadd_nsw, :(+)))
709-
return i, 0x03 # 00000011 - both positive
710-
elseif Base.sym_in(instruction(opp).instr, (:vsub_nsw, :(-)))
711-
oppp1 = parents(opp)[1]
712-
if isu₁unrolled(oppp1)
713-
return i, 0x01 # 00000001 - u₁ positive, u₂ negative
714-
else#isu₂unrolled(oppp1)
715-
return i, 0x02 # 00000010 - u₂ positive, u₁ negative
716-
end
717-
end
718-
end
694+
"""
    isuniqueinindices(ls::LoopSet, op::Operation, opp::Operation, i::Int)

Return `true` when none of the loop dependencies of `opp` (the parent operation found
for index `i` of `op`) show up in any *other* index of `op`, and `false` otherwise.

Every index position `j != i` of `op` is inspected:

- when `op.ref.loopedindex[j]` is `true`, the index symbol itself is tested for
  membership in `loopdependencies(opp)`;
- otherwise the parent operation of that index is looked up with `findparent`, and its
  loop dependencies are tested for any overlap with `loopdependencies(opp)`.
"""
function isuniqueinindices(ls::LoopSet, op::Operation, opp::Operation, i::Int)
    ld = loopdependencies(opp)
    inds = getindicesonly(op)
    li = op.ref.loopedindex
    for j ∈ eachindex(inds)
        # Index `i` is the one `opp` was derived from; only the others matter.
        i == j && continue
        if li[j]
            inds[j] ∈ ld && return false
        else
            # Look up the parent of index `j` (not `i`): using `i` here would compare
            # `opp`'s dependencies against themselves and reject every reference that
            # has a second computed index. A separate local keeps `opp` untouched.
            # NOTE(review): the DISCONTIGUOUS offset mirrors `isoptranslation`; if
            # `getindicesonly` already drops that marker the offset is always zero —
            # confirm against its definition.
            other = findparent(ls, inds[j + (first(inds) === DISCONTIGUOUS)])
            any(in(ld), loopdependencies(other)) && return false
        end
    end
    return true
end
709+
"""
    isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)

Search the indices of `op` for one that is computed by adding or subtracting values
that depend on both unrolled loops, and report where it is and with which signs.

Returns a tuple `(i, flag)`:

- `(0, 0x00)` — no such index was found, or a precondition failed;
- `(i, 0x03)` — index `i` is produced by `+` / `vadd_nsw` (both positive);
- `(i, 0x01)` — index `i` is produced by `-` / `vsub_nsw` and the first parent of that
  operation is u₁-unrolled (u₁ positive, u₂ negative);
- `(i, 0x02)` — index `i` is produced by `-` / `vsub_nsw` and the first parent is not
  u₁-unrolled (u₂ positive, u₁ negative).

Preconditions checked before the search: the vectorized loop is neither of the unrolled
loops, `op` is unrolled along both, and both loop steps are known with equal absolute
hints. A candidate index is additionally rejected (yielding `(0, 0x00)`) when
`isuniqueinindices` reports that its loops are used by another index of `op`.
"""
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
    @unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
    (vloopsym == u₁loopsym || vloopsym == u₂loopsym) && return 0, 0x00
    (isu₁unrolled(op) && isu₂unrolled(op)) || return 0, 0x00
    u₁step = step(getloop(ls, u₁loopsym))
    u₂step = step(getloop(ls, u₂loopsym))
    (isknown(u₁step) & isknown(u₂step)) || return 0, 0x00
    abs(gethint(u₁step)) == abs(gethint(u₂step)) || return 0, 0x00

    inds = getindices(op)
    li = op.ref.loopedindex
    for i ∈ eachindex(li)
        # Only indices that are not plain loop indices can be a sum/difference.
        li[i] && continue
        # `inds` may start with the DISCONTIGUOUS marker, which shifts positions by one.
        opp = findparent(ls, inds[i + (first(inds) === DISCONTIGUOUS)])
        (isu₁unrolled(opp) & isu₂unrolled(opp)) || continue
        instr = instruction(opp).instr
        if Base.sym_in(instr, (:vadd_nsw, :(+)))
            isuniqueinindices(ls, op, opp, i) || return 0, 0x00
            return i, 0x03 # 00000011 - both positive
        elseif Base.sym_in(instr, (:vsub_nsw, :(-)))
            isuniqueinindices(ls, op, opp, i) || return 0, 0x00
            oppp1 = parents(opp)[1]
            if isu₁unrolled(oppp1)
                return i, 0x01 # 00000001 - u₁ positive, u₂ negative
            else # isu₂unrolled(oppp1)
                return i, 0x02 # 00000010 - u₂ positive, u₁ negative
            end
        end
    end
    return 0, 0x00
end
723741
# function maxnegativeoffset_old(ls::LoopSet, op::Operation, u::Symbol)
724742
# opmref = op.ref

0 commit comments

Comments
 (0)