@@ -92,6 +92,21 @@ function return_type_enforcer(::Type{Return}) where {Return}
9292 Base. Fix2 (typeassert, Return)
9393end
9494
95+ function typed_callable_no_special_casing (callable:: Callable , :: Type{Return} ) where {
96+ Return, Callable,
97+ }
98+ ret = return_type_enforcer (Return)
99+ ret ∘ callable
100+ end
101+
102+ function typed_callable_no_special_casing (callable:: Callable , :: Type{Return} , :: Type{Arguments} ) where {
103+ Return, Arguments <: Tuple , Callable,
104+ }
105+ ret = return_type_enforcer (Return)
106+ with_argument_types = CallableWithArgumentTypes {Arguments} (callable)
107+ ret ∘ with_argument_types
108+ end
109+
95110"""
96111 typed_callable(callable, return_type::Type, argument_types::Type{<:Tuple})::CallableWithTypeSignature{return_type, argument_types}
97112
@@ -118,9 +133,11 @@ ERROR: TypeError: in typeassert, expected Tuple{Float32, Float32}, got a value o
118133function typed_callable (callable:: Callable , :: Type{Return} , :: Type{Arguments} ) where {
119134 Return, Arguments <: Tuple , Callable,
120135}
121- ret = return_type_enforcer (Return)
122- with_argument_types = CallableWithArgumentTypes {Arguments} (callable)
123- ret ∘ with_argument_types
136+ if callable isa CallableWithTypeSignature{Return, Arguments} # ensure idempotence
137+ callable
138+ else
139+ typed_callable_no_special_casing (callable, Return, Arguments)
140+ end
124141end
125142
126143"""
@@ -151,8 +168,11 @@ ERROR: TypeError: in typeassert, expected Float32, got a value of type Float64
151168function typed_callable (callable:: Callable , :: Type{Return} ) where {
152169 Return, Callable,
153170}
154- ret = return_type_enforcer (Return)
155- ret ∘ callable
171+ if callable isa CallableWithReturnType{Return} # ensure idempotence
172+ callable
173+ else
174+ typed_callable_no_special_casing (callable, Return)
175+ end
156176end
157177
158178end
0 commit comments