Skip to content

Commit 0e4c548

Browse files
committed
Replace RecursiveApply interface with MathWrapper
1 parent 54df0a9 commit 0e4c548

File tree

4 files changed

+467
-32
lines changed

4 files changed

+467
-32
lines changed

docs/src/APIs/utilities_api.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@ Utilities.PlusHalf
99
Utilities.half
1010
```
1111

12+
## Utilities.MathMapper
13+
14+
```@docs
15+
Utilities.MathMapper
16+
Utilities.supports_math_mapper
17+
Utilities.nested_math_mapper
18+
Utilities.math_mapper_broadcast
19+
Utilities.reduce_math_mapper_broadcast
20+
Utilities.math_mapper_type_broadcast
21+
Utilities.reduce_math_mapper_type_broadcast
22+
Utilities.@math_mapper_method
23+
```
24+
1225
## Utilities.Cache
1326

1427
```@docs

src/Utilities/Utilities.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Utilities
22

3+
include("math_mapper.jl")
34
include("plushalf.jl")
45
include("cache.jl")
56

@@ -60,5 +61,47 @@ function unionall_type(::Type{T}) where {T}
6061
return T.name.wrapper
6162
end
6263

64+
"""
65+
inferred_result_type(f, types...)
66+
67+
The type of result generated by calling `f` with arguments of the given types.
68+
An error is thrown if the compiler is unable to infer a concrete type, or if
69+
there is no method available for the specified argument types.
70+
"""
71+
function inferred_result_type(f::F, types...) where {F}
72+
if unrolled_any(==(Union{}), types)
73+
union_index = findfirst(==(Union{}), types)
74+
throw(ArgumentError("Cannot infer type of argument $union_index to $f"))
75+
end
76+
hasmethod(f, Tuple{types...}) || throw(MethodError(f, Tuple{types...}))
77+
inferred_type = Core.Compiler.return_type(f, Tuple{types...})
78+
if !isconcretetype(inferred_type)
79+
types_string = join(map(Base.Fix1(*, "::"), types), ", ")
80+
throw(ArgumentError("Cannot infer concrete type of $f($types_string)"))
81+
end
82+
return inferred_type
83+
end
84+
85+
"""
86+
inferred_result_value(f, types...)
87+
88+
The result of calling `f` with arguments of the given types, computed by using
89+
[`inferred_result_type`](@ref) in conjunction with a `Val` wrapper. The value
90+
returned by a function can only be inferred when it is a compile-time constant
91+
(i.e., when it is marked as a `Core.Const` in the output of `@code_warntype`).
92+
"""
93+
function inferred_result_value(f::F, types...) where {F}
94+
inferred_result_type(f, types...) # First check whether the type is inferred
95+
inferred_val_type = Core.Compiler.return_type(Val f, Tuple{types...})
96+
if !isconcretetype(inferred_val_type)
97+
types_string = join(map(Base.Fix1(*, "::"), types), ", ")
98+
throw(ArgumentError("Cannot infer constant value of $f($types_string)"))
99+
end
100+
return val_type_parameter(inferred_val_type)
101+
end
102+
103+
# Wrap values passed between functions in Vals to guarantee constant-propagation
104+
val_parameter(::Val{constant}) where {constant} = constant
105+
val_type_parameter(::Type{Val{constant}}) where {constant} = constant
63106

64107
end # module

0 commit comments

Comments
 (0)