@@ -40,6 +40,8 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040backing (project:: ProjectTo ) = getfield (project, :info )
4141
4242project_type (p:: ProjectTo{T} ) where {T} = T
43+ project_type (:: Type{<:ProjectTo{T}} ) where {T} = T
44+ project_type (_) = Any
4345
4446function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
4547 print (io, " ProjectTo{" )
@@ -142,42 +144,16 @@ end
142144# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
143145(:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
144146
145- # ####
146- # #### A related utility which wants to live nearby
147- # ####
148-
149- """
150- is_non_differentiable(x) == is_non_differentiable(typeof(x))
151-
152- Returns `true` if `x` is known from its type not to have derivatives, else `false`.
153-
154- Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
155- which is what the fallback method checks. The exception is that it will not look
156- inside abstractly typed containers like `x = Any[true, false]`.
157- """
158- is_non_differentiable (x) = is_non_differentiable (typeof (x))
159-
160- is_non_differentiable (:: Type{<:Number} ) = false
161- is_non_differentiable (:: Type{<:NTuple{N,T}} ) where {N,T} = is_non_differentiable (T)
162- is_non_differentiable (:: Type{<:AbstractArray{T}} ) where {T} = is_non_differentiable (T)
163-
164- function is_non_differentiable (:: Type{T} ) where {T} # fallback
165- PT = Base. _return_type (ProjectTo, Tuple{T}) # might be Union{} if unstable
166- return isconcretetype (PT) && PT <: ProjectTo{<:AbstractZero}
167- end
168-
169147# ####
170148# #### `Base`
171149# ####
172150
173151# Bool
174152ProjectTo (:: Bool ) = ProjectTo {NoTangent} () # same projector as ProjectTo(::AbstractZero) above
175- is_non_differentiable (:: Type{Bool} ) = true
176153
177154# Other never-differentiable types
178- for T in (:Symbol , :Char , :AbstractString , :RoundingMode , :IndexStyle )
155+ for T in (:Symbol , :Char , :AbstractString , :RoundingMode , :IndexStyle , :Nothing )
179156 @eval ProjectTo (:: $T ) = ProjectTo {NoTangent} ()
180- @eval is_non_differentiable (:: Type{<:$T} ) = true
181157end
182158
183159# Numbers
@@ -627,3 +603,40 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
627603 invoke (project, Tuple{AbstractArray}, dx)
628604 end
629605end
606+
607+ # ####
608+ # #### A related utility which wants to live nearby
609+ # ####
610+
611+ """
612+ differential_type(x)
613+ differential_type(typeof(x))
614+
615+ Testing `differential_type(x) <: AbstractZero` will tell you whether `x` is
616+ known to be non-differentiable.
617+
618+ This relies on `ProjectTo(x)`, and the method accepting a type relies on type inference.
619+ Thus it will not look inside abstractly typed containers such as `x = Any[true, false]`.
620+
621+ ```jldoctest
622+ julia> differential_type(true)
623+ NoTangent
624+
625+ julia> differential_type(Int)
626+ Float64
627+
628+ julia> x = Any[true, false];
629+
630+ julia> differential_type(x)
631+ NoTangent
632+
633+ julia> differential_type(typeof(x))
634+ Any
635+ ```
636+ """
637+ differential_type (x) = project_type (ProjectTo (x))
638+
639+ function differential_type (:: Type{T} ) where {T}
640+ PT = Base. _return_type (ProjectTo, Tuple{T}) # might be Union{} if unstable
641+ return isconcretetype (PT) ? project_type (PT) : Any
642+ end
0 commit comments