@@ -582,8 +582,8 @@ function matmultarget(ex, target, parsed, store::NamedTuple, call::CallInfo)
582582 @capture (ex, A_ * B_ * C__ | * (A_, B_, C__) ) || throw (MacroError (" can't @matmul that!" , call))
583583
584584 # Figure out what to sum over, and make A,B into matrices ready for *
585- iA = guesstarget (A)
586- iB = guesstarget (B)
585+ iA = guesstarget (A, [], [] )
586+ iB = guesstarget (B, [], [] )
587587
588588 isum = sort (intersect (iA, iB, parsed. reduced),
589589 by = i -> findfirst (isequal (i), target)) # or target? parsed.reduced
@@ -857,11 +857,14 @@ function reduceparse(ex1, ex2, store::NamedTuple, call::CallInfo)
857857 canon = vcat (leftcanon, reduced)
858858 else
859859 # But for Z = sum(A, dims=...) can try to avoid permutedims, not sure it matters.
860- guess = guesstarget (ex2) # TODO make guess smarter, use leftcanon as a target
861- # @show guess reduced
862- [ deleteat! (guess, findcheck (i, guess, call, " on the right" )) for i in reduced ]
863- if leftcanon == guess
864- canon = guesstarget (ex2)
860+ guess = guesstarget (ex2, leftcanon, reduced)
861+ guessminus = copy (guess)
862+ for i in reduced
863+ deleteat! (guessminus, findcheck (i, guessminus, call, " on the right" ))
864+ end
865+ if leftcanon == guessminus # i.e. leftcanon is an ordered subset of guess
866+ canon = guess
867+ # canon == vcat(leftcanon, reduced) || @info "guesstarget did something!" repr(leftcanon) repr(reduced) repr(guess) ex2
865868 else
866869 canon = vcat (leftcanon, reduced)
867870 end
@@ -1254,16 +1257,26 @@ function listindices(ex::Expr)
12541257 list
12551258end
12561259
1257- function guesstarget (ex:: Expr )
1260+ listsymbols (s:: Symbol , target) = s in target ? [s] : Symbol[]
1261+ listsymbols (any, target) = Symbol[]
1262+ function listsymbols (ex:: Expr , target)
1263+ ex. head == :vec && return Symbol[]
1264+ return union ((listsymbols (a, target) for a in ex. args). .. )
1265+ end
1266+
1267+ function guesstarget (ex:: Expr , left, red)
12581268 list = sort (listindices (ex), by= length, rev= true )
1259- shortlist = unique (reduce (vcat, list))
1269+ naked = listsymbols (ex, vcat (left, red))
1270+ unique (vcat (list... , naked)) # TODO make a smarter version which tries to fit to left + red?
12601271end
12611272
12621273# function overlapsorted(x,y) # works fine but not in use yet
12631274# z = intersect(x,y)
12641275# length(z) ==0 && return true
12651276# xi = map(i -> findfirst(isequal(i),x), z)
1277+ # @assert xi == indexin(z, x)
12661278# yi = map(i -> findfirst(isequal(i),y), z)
1279+ # @assert xi == indexin(z, y)
12671280# return sortperm(xi) == sortperm(yi)
12681281# end
12691282
0 commit comments