
Commit 3f679d9 (1 parent: 003b06f)

use dropdims again, after reductions

File tree: 2 files changed (+12, -12 lines)


docs/src/reduce.md

Lines changed: 9 additions & 10 deletions
@@ -17,7 +17,8 @@ julia> @reduce S[i] := sum(j) M[i,j] + 1000
 
 julia> @pretty @reduce S[i] := sum(j) M[i,j] + 1000
 begin
-    S = transmute(sum(@__dot__(M + 1000), dims = 2), (1,))
+    ndims(M) == 2 || throw(ArgumentError("expected a 2-tensor M[i, j]"))
+    S = dropdims(sum(@__dot__(M + 1000), dims = 2), dims=2)
 end
 ```
 
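
A quick sanity check of the new expansion, outside the diff: the hand-written `dropdims` form should match what `@reduce` returns. The matrix `M` below is just an illustrative input, not taken from the docs:

```julia
using TensorCast

M = reshape(1:6, 2, 3)                        # any 2-tensor will do

@reduce S[i] := sum(j) M[i,j] + 1000          # the macro call being expanded above
S_manual = dropdims(sum(M .+ 1000, dims = 2), dims = 2)   # the new expansion, written by hand

S == S_manual == vec(sum(M .+ 1000, dims = 2))   # true; the summed dimension j is gone
```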

@@ -26,9 +27,6 @@ Note that:
 * The sum applies to the whole right side (including the 1000 here).
 * And the summed dimensions are always dropped (unless you explicitly say `S[i,_] := ...`).
 
-Here `transmute(..., (1,))` is equivalent to `dropdims(..., dims=2)`,
-keeping just the first dimension by reshaping.
-
 ## Not just `sum`
 
 You may use any reduction funciton which understands keyword `dims=...`, like `sum` does.
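
The deleted sentence stated the equivalence `transmute(..., (1,))` == `dropdims(..., dims=2)` for this case; a minimal check of that claim, calling `transmute` through `TensorCast` as the old expansion did, on a throwaway array:

```julia
using TensorCast

A = sum(rand(2, 3) .+ 1000, dims = 2)        # 2×1 Matrix, as produced by sum(..., dims = 2)

old_way = TensorCast.transmute(A, (1,))      # keep only the first dimension
new_way = dropdims(A, dims = 2)              # drop the trailing size-1 dimension

old_way == new_way                           # true; both are length-2 and one-dimensional
```
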
@@ -112,10 +110,10 @@ There is no need to name the intermediate array, here `termite[x]`, but you must
 ```julia-repl
 julia> @pretty @reduce sum(x,θ) L[x,θ] * p[θ] * log(L[x,θ] / @reduce _[x] := sum(θ′) L[x,θ′] * p[θ′])
 begin
-    local fish = transmute(p, (nothing, 1))
-    termite = transmute(sum(@__dot__(L * fish), dims = 2), (1,))
-    local wallaby = transmute(p, (nothing, 1))
-    rat = sum(@__dot__(L * wallaby * log(L / termite)))
+    ndims(L) == 2 || error() # etc, some checks
+    local goshawk = transmute(p, (nothing, 1))
+    sandpiper = dropdims(sum(@__dot__(L * goshawk), dims = 2), dims = 2) # inner sum
+    bison = sum(@__dot__(L * goshawk * log(L / sandpiper)))
 end
 ```
 
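
The gensym-style names (`goshawk`, `sandpiper`, `bison`) are just `@pretty`'s stand-ins. As a rough check, with made-up small arrays, the printed expansion should agree with a direct comprehension:

```julia
using TensorCast

L = rand(4, 3); p = rand(3)

r1 = @reduce sum(x,θ) L[x,θ] * p[θ] * log(L[x,θ] / @reduce _[x] := sum(θ′) L[x,θ′] * p[θ′])

inner = dropdims(sum(L .* p', dims = 2), dims = 2)   # the inner sum over θ′, as in the expansion
r2 = sum(L[x,θ] * p[θ] * log(L[x,θ] / inner[x]) for x in axes(L,1), θ in axes(L,2))

r1 ≈ r2   # true
```
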

@@ -131,8 +129,9 @@ before summing over one index:
 ```julia-repl
 julia> @pretty @reduce R[i,k] := sum(j) M[i,j] * N[j,k]
 begin
-    local fish = transmute(N, (nothing, 1, 2)) # fish = reshape(N, 1, size(N)...)
-    R = transmute(sum(@__dot__(M * fish), dims = 2), (1, 3)) # R = dropdims(sum(...), dims=2)
+    size(M, 2) == size(N, 1) || error() # etc, some checks
+    local fish = transmute(N, (nothing, 1, 2)) # fish = reshape(N, 1, size(N)...)
+    R = dropdims(sum(@__dot__(M * fish), dims = 2), dims = 2)
 end
 ```
 
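
This last example is just matrix multiplication in disguise, which gives an easy end-to-end check of the new `dropdims`-based expansion (the arrays below are arbitrary test inputs):

```julia
using TensorCast

M = rand(2, 3); N = rand(3, 4)

@reduce R[i,k] := sum(j) M[i,j] * N[j,k]

R ≈ M * N        # true: summing over j and then dropping it reproduces ordinary matmul
size(R)          # (2, 4), with the reduced dimension j gone
```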

src/macro.jl

Lines changed: 3 additions & 2 deletions
@@ -1436,8 +1436,9 @@ function newoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo)
         ex = :( $(parsed.redfun)($ex) )
     else
         dims = length(parsed.rdims)>1 ? Tuple(parsed.rdims) : parsed.rdims[1]
-        perm = Tuple(filter(d -> !(d in parsed.rdims), 1:length(canon)))
-        ex = :( TensorCast.transmute($(parsed.redfun)($ex, dims=$dims), $perm) )
+        # perm = Tuple(filter(d -> !(d in parsed.rdims), 1:length(canon)))
+        # ex = :( TensorCast.transmute($(parsed.redfun)($ex, dims=$dims), $perm) )
+        ex = :( Base.dropdims($(parsed.redfun)($ex, dims=$dims), dims=$dims) )
         if :strided in call.flags
             pop!(call.flags, :collected, :ok) # makes stridedview(...
         end
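
Note that the `dims` computed just above is either a single integer or a tuple, and `Base.dropdims` accepts both, so one generated expression covers reductions over one or several indices. A standalone sketch of the generated pattern, with `sum` standing in for `parsed.redfun`:

```julia
A = rand(2, 3, 4)

# one reduced index: dims = 2
one_dim = dropdims(sum(A, dims = 2), dims = 2)              # size (2, 4)

# two reduced indices: dims = (2, 3)
two_dims = dropdims(sum(A, dims = (2, 3)), dims = (2, 3))   # size (2,)

size(one_dim), size(two_dims)    # ((2, 4), (2,))
```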
