1
- patch (x, x̄) = x .- x̄
2
1
3
2
function state (o, x)
4
3
if isnumeric (x)
@@ -11,27 +10,68 @@ function state(o, x)
11
10
end
12
11
end
13
12
14
- function _update (o, st, x, x̄s... )
15
- st′, x̄′ = apply (o, st, x, x̄s... )
16
- return st′, patch (x, x̄′)
13
+ patch! (x, x̄) = iswriteable (x) ? (x .= x .- x̄) : (x .- x̄)
14
+
15
+ function _update! (o, st, x, x̄s... )
16
+ st′, x̄′ = apply! (o, st, x, x̄s... )
17
+ return st′, patch! (x, x̄′)
17
18
end
18
19
19
- function update (o, state, x:: T , x̄s... ) where T
20
+ function update! (o, state, x, x̄s... )
20
21
if all (isnothing, x̄s)
21
22
return state, x
22
23
elseif isnumeric (x)
23
- return _update (o, state, x, x̄s... )
24
+ return _update! (o, state, x, x̄s... )
24
25
else
25
26
x̄s′ = map (x̄ -> functor (typeof (x), x̄)[1 ], x̄s)
26
27
x′, re = functor (typeof (x), x)
27
- xstate = map ((stᵢ, xᵢ, x̄sᵢ... ) -> update (o, stᵢ, xᵢ, x̄sᵢ... ), state, x′, x̄s′... )
28
+ xstate = map ((stᵢ, xᵢ, x̄sᵢ... ) -> update! (o, stᵢ, xᵢ, x̄sᵢ... ), state, x′, x̄s′... )
28
29
return map (first, xstate), re (map (last, xstate))
29
30
end
30
31
end
31
32
33
+ function update (o, state, x, x̄s... )
34
+ state′ = fmap (copy, state; exclude = iswriteable)
35
+ x′ = fmap (copy, x; exclude = iswriteable)
36
+ update! (o, state′, x′, x̄s... )
37
+ end
38
+
32
39
# default all rules to first order calls
33
- apply (o, state, x, dx, dxs... ) = apply (o, state, x, dx)
40
+ apply! (o, state, x, dx, dxs... ) = apply! (o, state, x, dx)
34
41
35
42
isnumeric (x:: AbstractArray{<:Number} ) = isleaf (x) # isleaf to allow for e.g. transposed shared weights
36
43
isnumeric (x:: AbstractArray{<:Bool} ) = false # convention of ChainRules is that Bool is non-differentiable
37
44
isnumeric (x) = false
45
+
46
+ iswriteable (:: DenseArray{<:AbstractFloat} ) = true # more elaborate versions are possible, wait until needed?
47
+ iswriteable (_) = false
48
+
49
+ """
50
+ @.. x = x + y
51
+ @.. x + y / z
52
+
53
+ Magic broadcasting macro, for use in `apply!` rules:
54
+ * Applied to assignment `x = ...` it is like `@.` unless `!iswriteable(x)`,
55
+ in which case it ignores `x`, and applies `@.` on the right.
56
+ * Applied to other expressions, it broadcasts like `@.` but does not materialise,
57
+ returning a `Broadcasted` object for later use.
58
+ """
59
+ macro var".." (ex)
60
+ if Meta. isexpr (ex, :(= ))
61
+ dst = esc (ex. args[1 ])
62
+ src = esc (Broadcast. __dot__ (ex. args[2 ]))
63
+ :(if $ iswriteable ($ dst)
64
+ $ dst .= $ src
65
+ else
66
+ $ src
67
+ end )
68
+ else
69
+ bc = esc (Broadcast. __dot__ (ex))
70
+ :($ lazy .($ bc))
71
+ end
72
+ end
73
+
74
+ function lazy end
75
+ Broadcast. broadcasted (:: typeof (lazy), x) = Lazy (x)
76
+ struct Lazy{T}; bc:: T ; end
77
+ Broadcast. materialize (x:: Lazy ) = Broadcast. instantiate (x. bc)
0 commit comments