@@ -16,54 +16,37 @@ Notice:
1616
1717# we propagate `NotImplemented` (e.g., in `@scalar_rule`)
1818# this requires the following definitions (see also #337)
19- Base.:+ (x:: NotImplemented , :: ZeroTangent ) = x
20- Base.:+ (:: ZeroTangent , x:: NotImplemented ) = x
2119Base.:+ (x:: NotImplemented , :: NotImplemented ) = x
22- Base.:* (:: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
23- Base.:* (:: ZeroTangent , :: NotImplemented ) = ZeroTangent ()
24- for T in (:NoTangent , :AbstractThunk , :Tangent , :Any )
20+ Base.:* (x:: NotImplemented , :: NotImplemented ) = x
21+ LinearAlgebra. dot (x:: NotImplemented , :: NotImplemented ) = x
22+ # `NotImplemented` always "wins" +
23+ for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
2524 @eval Base.:+ (x:: NotImplemented , :: $T ) = x
2625 @eval Base.:+ (:: $T , x:: NotImplemented ) = x
26+ end
27+ # `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent
28+ # this can be used to ignore partial derivatives that are not implemented
29+ for T in (:ZeroTangent , :NoTangent )
30+ @eval Base.:* (:: NotImplemented , :: $T ) = $ T ()
31+ @eval Base.:* (:: $T , :: NotImplemented ) = $ T ()
32+ @eval LinearAlgebra. dot (:: NotImplemented , :: $T ) = $ T ()
33+ @eval LinearAlgebra. dot (:: $T , :: NotImplemented ) = $ T ()
34+ end
35+ # `NotImplemented` "wins" * and dot for other types
36+ for T in (:AbstractThunk , :Tangent , :Any )
2737 @eval Base.:* (x:: NotImplemented , :: $T ) = x
38+ @eval Base.:* (:: $T , x:: NotImplemented ) = x
39+ @eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = x
40+ @eval LinearAlgebra. dot (:: $T , x:: NotImplemented ) = x
2841end
29- Base. muladd (x:: NotImplemented , y, z) = x
30- Base. muladd (:: NotImplemented , :: ZeroTangent , z) = z
31- Base. muladd (x:: NotImplemented , y, :: ZeroTangent ) = x
32- Base. muladd (:: NotImplemented , :: ZeroTangent , :: ZeroTangent ) = ZeroTangent ()
33- Base. muladd (x, y:: NotImplemented , z) = y
34- Base. muladd (:: ZeroTangent , :: NotImplemented , z) = z
35- Base. muladd (x, y:: NotImplemented , :: ZeroTangent ) = y
36- Base. muladd (:: ZeroTangent , :: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
37- Base. muladd (x, y, z:: NotImplemented ) = z
38- Base. muladd (:: ZeroTangent , y, z:: NotImplemented ) = z
39- Base. muladd (x, :: ZeroTangent , z:: NotImplemented ) = z
40- Base. muladd (:: ZeroTangent , :: ZeroTangent , z:: NotImplemented ) = z
41- Base. muladd (x:: NotImplemented , :: NotImplemented , z) = x
42- Base. muladd (x:: NotImplemented , :: NotImplemented , :: ZeroTangent ) = x
43- Base. muladd (x:: NotImplemented , y, :: NotImplemented ) = x
44- Base. muladd (:: NotImplemented , :: ZeroTangent , z:: NotImplemented ) = z
45- Base. muladd (x, y:: NotImplemented , :: NotImplemented ) = y
46- Base. muladd (:: ZeroTangent , :: NotImplemented , z:: NotImplemented ) = z
47- Base. muladd (x:: NotImplemented , :: NotImplemented , :: NotImplemented ) = x
48- LinearAlgebra. dot (:: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
49- LinearAlgebra. dot (:: ZeroTangent , :: NotImplemented ) = ZeroTangent ()
50-
51- # other common operations throw an exception
52- Base.:+ (x:: NotImplemented ) = throw (NotImplementedException (x))
42+
43+ # subtraction throws an exception: in AD we add tangents but do not subtract them
44+ # subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
5345Base.:- (x:: NotImplemented ) = throw (NotImplementedException (x))
54- Base.:- (x:: NotImplemented , :: ZeroTangent ) = throw (NotImplementedException (x))
55- Base.:- (:: ZeroTangent , x:: NotImplemented ) = throw (NotImplementedException (x))
5646Base.:- (x:: NotImplemented , :: NotImplemented ) = throw (NotImplementedException (x))
57- Base.:* (x:: NotImplemented , :: NotImplemented ) = throw (NotImplementedException (x))
58- function LinearAlgebra. dot (x:: NotImplemented , :: NotImplemented )
59- return throw (NotImplementedException (x))
60- end
61- for T in (:NoTangent , :AbstractThunk , :Tangent , :Any )
47+ for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
6248 @eval Base.:- (x:: NotImplemented , :: $T ) = throw (NotImplementedException (x))
6349 @eval Base.:- (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
64- @eval Base.:* (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
65- @eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = throw (NotImplementedException (x))
66- @eval LinearAlgebra. dot (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
6750end
6851
6952Base.:+ (:: NoTangent , :: NoTangent ) = NoTangent ()
0 commit comments