@@ -17,7 +17,8 @@ julia> @reduce S[i] := sum(j) M[i,j] + 1000
1717
1818julia> @pretty @reduce S[i] := sum(j) M[i,j] + 1000
1919begin
20- S = transmute(sum(@__dot__(M + 1000), dims = 2), (1,))
20+ ndims(M) == 2 || throw(ArgumentError("expected a 2-tensor M[i, j]"))
21+ S = dropdims(sum(@__dot__(M + 1000), dims = 2), dims=2)
2122end
2223```
2324
@@ -26,9 +27,6 @@ Note that:
2627* The sum applies to the whole right side (including the 1000 here).
2728* And the summed dimensions are always dropped (unless you explicitly say ` S[i,_] := ... ` ).
2829
29- Here ` transmute(..., (1,)) ` is equivalent to ` dropdims(..., dims=2) ` ,
30- keeping just the first dimension by reshaping.
31-
3230## Not just ` sum `
3331
3432You may use any reduction funciton which understands keyword ` dims=... ` , like ` sum ` does.
@@ -112,10 +110,10 @@ There is no need to name the intermediate array, here `termite[x]`, but you must
112110``` julia-repl
113111julia> @pretty @reduce sum(x,θ) L[x,θ] * p[θ] * log(L[x,θ] / @reduce _[x] := sum(θ′) L[x,θ′] * p[θ′])
114112begin
115- local fish = transmute(p, (nothing, 1))
116- termite = transmute(sum(@__dot__(L * fish), dims = 2), (1, ))
117- local wallaby = transmute(p, (nothing, 1))
118- rat = sum(@__dot__(L * wallaby * log(L / termite )))
113+ ndims(L) == 2 || error() # etc, some checks
114+ local goshawk = transmute(p, (nothing, 1 ))
115+ sandpiper = dropdims(sum(@__dot__(L * goshawk), dims = 2), dims = 2) # inner sum
116+ bison = sum(@__dot__(L * goshawk * log(L / sandpiper )))
119117end
120118```
121119
@@ -131,8 +129,9 @@ before summing over one index:
131129``` julia-repl
132130julia> @pretty @reduce R[i,k] := sum(j) M[i,j] * N[j,k]
133131begin
134- local fish = transmute(N, (nothing, 1, 2)) # fish = reshape(N, 1, size(N)...)
135- R = transmute(sum(@__dot__(M * fish), dims = 2), (1, 3)) # R = dropdims(sum(...), dims=2)
132+ size(M, 2) == size(N, 1) || error() # etc, some checks
133+ local fish = transmute(N, (nothing, 1, 2)) # fish = reshape(N, 1, size(N)...)
134+ R = dropdims(sum(@__dot__(M * fish), dims = 2), dims = 2)
136135end
137136```
138137
0 commit comments