|
1 | 1 | module Utilities |
2 | 2 |
|
| 3 | +include("math_mapper.jl") |
3 | 4 | include("plushalf.jl") |
4 | 5 | include("cache.jl") |
5 | 6 |
|
@@ -60,5 +61,47 @@ function unionall_type(::Type{T}) where {T} |
60 | 61 | return T.name.wrapper |
61 | 62 | end |
62 | 63 |
|
| 64 | +""" |
| 65 | + inferred_result_type(f, types...) |
| 66 | +
|
| 67 | +The type of result generated by calling `f` with arguments of the given types. |
| 68 | +An error is thrown if the compiler is unable to infer a concrete type, or if |
| 69 | +there is no method available for the specified argument types. |
| 70 | +""" |
| 71 | +function inferred_result_type(f::F, types...) where {F} |
| 72 | + if unrolled_any(==(Union{}), types) |
| 73 | + union_index = findfirst(==(Union{}), types) |
| 74 | + throw(ArgumentError("Cannot infer type of argument $union_index to $f")) |
| 75 | + end |
| 76 | + hasmethod(f, Tuple{types...}) || throw(MethodError(f, Tuple{types...})) |
| 77 | + inferred_type = Core.Compiler.return_type(f, Tuple{types...}) |
| 78 | + if !isconcretetype(inferred_type) |
| 79 | + types_string = join(map(Base.Fix1(*, "::"), types), ", ") |
| 80 | + throw(ArgumentError("Cannot infer concrete type of $f($types_string)")) |
| 81 | + end |
| 82 | + return inferred_type |
| 83 | +end |
| 84 | + |
| 85 | +""" |
| 86 | + inferred_result_value(f, types...) |
| 87 | +
|
| 88 | +The result of calling `f` with arguments of the given types, computed by using |
| 89 | +[`inferred_result_type`](@ref) in conjunction with a `Val` wrapper. The value |
| 90 | +returned by a function can only be inferred when it is a compile-time constant |
| 91 | +(i.e., when it is marked as a `Core.Const` in the output of `@code_warntype`). |
| 92 | +""" |
| 93 | +function inferred_result_value(f::F, types...) where {F} |
| 94 | + inferred_result_type(f, types...) # First check whether the type is inferred |
| 95 | + inferred_val_type = Core.Compiler.return_type(Val ∘ f, Tuple{types...}) |
| 96 | + if !isconcretetype(inferred_val_type) |
| 97 | + types_string = join(map(Base.Fix1(*, "::"), types), ", ") |
| 98 | + throw(ArgumentError("Cannot infer constant value of $f($types_string)")) |
| 99 | + end |
| 100 | + return val_type_parameter(inferred_val_type) |
| 101 | +end |
| 102 | + |
| 103 | +# Wrap values passed between functions in Vals to guarantee constant-propagation |
| 104 | +val_parameter(::Val{constant}) where {constant} = constant |
| 105 | +val_type_parameter(::Type{Val{constant}}) where {constant} = constant |
63 | 106 |
|
64 | 107 | end # module |
0 commit comments