Skip to content

Commit a6ca2d1

Browse files
committed
fix reductions with only naked indices
1 parent a2a4996 commit a6ca2d1

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

src/macro.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,8 @@ function matmultarget(ex, target, parsed, store::NamedTuple, call::CallInfo)
582582
@capture(ex, A_ * B_ * C__ | *(A_, B_, C__) ) || throw(MacroError("can't @matmul that!", call))
583583

584584
# Figure out what to sum over, and make A,B into matrices ready for *
585-
iA = guesstarget(A)
586-
iB = guesstarget(B)
585+
iA = guesstarget(A, [], [])
586+
iB = guesstarget(B, [], [])
587587

588588
isum = sort(intersect(iA, iB, parsed.reduced),
589589
by = i -> findfirst(isequal(i), target)) # or target? parsed.reduced
@@ -857,11 +857,14 @@ function reduceparse(ex1, ex2, store::NamedTuple, call::CallInfo)
857857
canon = vcat(leftcanon, reduced)
858858
else
859859
# But for Z = sum(A, dims=...) can try to avoid permutedims, not sure it matters.
860-
guess = guesstarget(ex2) # TODO make guess smarter, use leftcanon as a target
861-
# @show guess reduced
862-
[ deleteat!(guess, findcheck(i, guess, call, " on the right")) for i in reduced ]
863-
if leftcanon == guess
864-
canon = guesstarget(ex2)
860+
guess = guesstarget(ex2, leftcanon, reduced)
861+
guessminus = copy(guess)
862+
for i in reduced
863+
deleteat!(guessminus, findcheck(i, guessminus, call, " on the right"))
864+
end
865+
if leftcanon == guessminus # i.e. leftcanon is an ordered subset of guess
866+
canon = guess
867+
# canon == vcat(leftcanon, reduced) || @info "guesstarget did something!" repr(leftcanon) repr(reduced) repr(guess) ex2
865868
else
866869
canon = vcat(leftcanon, reduced)
867870
end
@@ -1254,16 +1257,26 @@ function listindices(ex::Expr)
12541257
list
12551258
end
12561259

1257-
function guesstarget(ex::Expr)
1260+
listsymbols(s::Symbol, target) = s in target ? [s] : Symbol[]
1261+
listsymbols(any, target) = Symbol[]
1262+
function listsymbols(ex::Expr, target)
1263+
ex.head == :vec && return Symbol[]
1264+
return union((listsymbols(a, target) for a in ex.args)...)
1265+
end
1266+
1267+
function guesstarget(ex::Expr, left, red)
12581268
list = sort(listindices(ex), by=length, rev=true)
1259-
shortlist = unique(reduce(vcat, list))
1269+
naked = listsymbols(ex, vcat(left, red))
1270+
unique(vcat(list..., naked)) # TODO make a smarter version which tries to fit to left + red?
12601271
end
12611272

12621273
# function overlapsorted(x,y) # works fine but not in use yet
12631274
# z = intersect(x,y)
12641275
# length(z) ==0 && return true
12651276
# xi = map(i -> findfirst(isequal(i),x), z)
1277+
# @assert xi == indexin(z, x)
12661278
# yi = map(i -> findfirst(isequal(i),y), z)
1279+
# @assert xi == indexin(z, y)
12671280
# return sortperm(xi) == sortperm(yi)
12681281
# end
12691282

test/four.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ end
7373

7474
@cast C[i,j,k] := 0 * A[i,(j,k)] + j (k in 1:2) # used to infer sz_j = (:)
7575
@test all(==(2), C[:,2,:])
76+
77+
@reduce D[k] := sum(i) B[i]/k (k in 1:4) # no indexing by k on RHS
78+
@test D vec(sum(B ./ (1:4)', dims=1))
79+
80+
@reduce E := sum(i,k) i/k (i in 1:2, k in 1:4) # no indexing on RHS of reduction
81+
@test E sum((1:2) ./ (1:4)')
7682
end
7783

7884
@testset "tuples" begin

0 commit comments

Comments
 (0)