
Commit f1bb602

Add lazy broadcasting (#31)
* add lazy broadcasting
* magic copy keyword
* Revert "magic copy keyword". This reverts commit ff53e25.
* change all to apply!, explicitly copy state for update
* fixup
* change copy scheme
* more rules
* fixup
* delete the manifest
* broken
1 parent 44fc0ce commit f1bb602

File tree: 7 files changed (+166, -146 lines)


Manifest.toml

Lines changed: 0 additions & 57 deletions
This file was deleted.

Project.toml

Lines changed: 3 additions & 5 deletions
@@ -1,19 +1,17 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.1.0"
+version = "0.2.0"
 
 [deps]
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-Functors = "0.1, 0.2"
-Requires = "0.5, 1"
-julia = "1.6"
+Functors = "0.2.7"
+julia = "1.6"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/src/api.md

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ Optimisers.ADAMW
 Optimisers.AdaBelief
 ```
 
-```
+```@docs
 Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.WeightDecay

docs/src/index.md

Lines changed: 6 additions & 2 deletions
@@ -8,9 +8,9 @@ struct Descent{T}
   η::T
 end
 
-# Define an `apply` rule with which to update the current params
+# Define an `apply!` rule with which to update the current params
 # using the gradients
-function Optimisers.apply(o::Descent, state, m, m̄)
+function Optimisers.apply!(o::Descent, state, m, m̄)
   o.η .* m̄, state
 end
 
@@ -52,3 +52,7 @@ tree formed by the model and update the parameters using the gradients. Optimise
 work with different forms of gradients, but most likely use case are the gradients as
 returned by [Zygote.jl](https://fluxml.ai/Zygote.jl).
 
+There is also `Optimisers.update!` which similarly returns a new model and new state,
+but is free to mutate arrays within the old one, for efficiency.
+The method of `apply!` you write is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.
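
The paragraph added above distinguishes `update`, which defensively copies writeable arrays, from `update!`, which may overwrite them. A minimal usage sketch of that distinction, assuming the package's built-in `Descent(η)` rule and the `state`/`update`/`update!` signatures introduced in this commit; the concrete values are illustrative only:

```julia
using Optimisers

o = Optimisers.Descent(0.1)       # built-in rule; constructor assumed, learning rate η = 0.1
x = [1.0, 2.0, 3.0]               # parameters: a single numeric leaf
x̄ = [0.3, 0.3, 0.3]               # a gradient for x, e.g. as returned by Zygote

s = Optimisers.state(o, x)        # set up the optimiser state for x

# Non-mutating path: `update` copies writeable arrays in the state and the model
# (via `fmap(copy, ...; exclude = iswriteable)`) before delegating to `update!`,
# so the original `x` and `s` are left untouched.
s′, x′ = Optimisers.update(o, s, x, x̄)

# Mutating path: `update!` is free to overwrite arrays inside `s` and `x` in place.
s′, x′ = Optimisers.update!(o, s, x, x̄)
```

Both calls return the new state first and the new model second, matching `update!` in src/interface.jl below.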

src/interface.jl

Lines changed: 48 additions & 8 deletions
@@ -1,4 +1,3 @@
-patch(x, x̄) = x .- x̄
 
 function state(o, x)
   if isnumeric(x)
@@ -11,27 +10,68 @@ function state(o, x)
   end
 end
 
-function _update(o, st, x, x̄s...)
-  st′, x̄′ = apply(o, st, x, x̄s...)
-  return st′, patch(x, x̄′)
+patch!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
+
+function _update!(o, st, x, x̄s...)
+  st′, x̄′ = apply!(o, st, x, x̄s...)
+  return st′, patch!(x, x̄′)
 end
 
-function update(o, state, x::T, x̄s...) where T
+function update!(o, state, x, x̄s...)
   if all(isnothing, x̄s)
     return state, x
   elseif isnumeric(x)
-    return _update(o, state, x, x̄s...)
+    return _update!(o, state, x, x̄s...)
   else
     x̄s′ = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
     x′, re = functor(typeof(x), x)
-    xstate = map((stᵢ, xᵢ, x̄sᵢ...) -> update(o, stᵢ, xᵢ, x̄sᵢ...), state, x′, x̄s′...)
+    xstate = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(o, stᵢ, xᵢ, x̄sᵢ...), state, x′, x̄s′...)
     return map(first, xstate), re(map(last, xstate))
   end
 end
 
+function update(o, state, x, x̄s...)
+  state′ = fmap(copy, state; exclude = iswriteable)
+  x′ = fmap(copy, x; exclude = iswriteable)
+  update!(o, state′, x′, x̄s...)
+end
+
 # default all rules to first order calls
-apply(o, state, x, dx, dxs...) = apply(o, state, x, dx)
+apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
 
 isnumeric(x::AbstractArray{<:Number}) = isleaf(x)  # isleaf to allow for e.g. transposed shared weights
 isnumeric(x::AbstractArray{<:Bool}) = false  # convention of ChainRules is that Bool is non-differentiable
 isnumeric(x) = false
+
+iswriteable(::DenseArray{<:AbstractFloat}) = true  # more elaborate versions are possible, wait until needed?
+iswriteable(_) = false
+
+"""
+    @.. x = x + y
+    @.. x + y / z
+
+Magic broadcasting macro, for use in `apply!` rules:
+* Applied to assignment `x = ...` it is like `@.` unless `!iswriteable(x)`,
+  in which case it ignores `x`, and applies `@.` on the right.
+* Applied to other expressions, it broadcasts like `@.` but does not materialise,
+  returning a `Broadcasted` object for later use.
+"""
+macro var".."(ex)
+  if Meta.isexpr(ex, :(=))
+    dst = esc(ex.args[1])
+    src = esc(Broadcast.__dot__(ex.args[2]))
+    :(if $iswriteable($dst)
+        $dst .= $src
+      else
+        $src
+      end)
+  else
+    bc = esc(Broadcast.__dot__(ex))
+    :($lazy.($bc))
+  end
+end
+
+function lazy end
+Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)
+struct Lazy{T}; bc::T; end
+Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
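
The docstring of `@..` above describes two behaviours: an in-place assignment when the destination is writeable, and a lazy, un-materialised broadcast otherwise. The sketch below pastes the relevant definitions from src/interface.jl into a standalone script and adds a few illustrative calls at the end; the variable names and values are not part of the commit.

```julia
# Definitions copied from src/interface.jl in this commit:
iswriteable(::DenseArray{<:AbstractFloat}) = true
iswriteable(_) = false

function lazy end
struct Lazy{T}; bc::T; end
Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)
Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)

macro var".."(ex)
  if Meta.isexpr(ex, :(=))
    dst = esc(ex.args[1])
    src = esc(Broadcast.__dot__(ex.args[2]))
    :(if $iswriteable($dst)
        $dst .= $src
      else
        $src
      end)
  else
    bc = esc(Broadcast.__dot__(ex))
    :($lazy.($bc))
  end
end

# Illustrative usage (not part of the commit):
x = [1.0, 2.0]
y = [10.0, 20.0]

@.. x = x + y                    # x is a DenseArray{Float64}, so this writes into x
# now x == [11.0, 22.0]

r = 1.0:2.0                      # a range is not writeable
s = (@.. r = r + 1)              # falls back to an out-of-place broadcast; r is untouched

bc = @.. x * 2 + y               # no assignment: a lazy Broadcasted object, nothing evaluated yet
z = Broadcast.materialize(bc)    # fuse and evaluate it on demand
```

In `apply!` rules this lets the same expression mutate dense arrays and fall back to fresh results otherwise, which is what the copy scheme in `update` above relies on.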

0 commit comments
