Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/EnforcedTypeSignatureCallables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ function return_type_enforcer(::Type{Return}) where {Return}
Base.Fix2(typeassert, Return)
end

function typed_callable_no_special_casing(callable::Callable, ::Type{Return}) where {
Return, Callable,
}
ret = return_type_enforcer(Return)
ret ∘ callable
end

function typed_callable_no_special_casing(callable::Callable, ::Type{Return}, ::Type{Arguments}) where {
Return, Arguments <: Tuple, Callable,
}
ret = return_type_enforcer(Return)
with_argument_types = CallableWithArgumentTypes{Arguments}(callable)
ret ∘ with_argument_types
end

"""
typed_callable(callable, return_type::Type, argument_types::Type{<:Tuple})::CallableWithTypeSignature{return_type, argument_types}

Expand All @@ -118,9 +133,11 @@ ERROR: TypeError: in typeassert, expected Tuple{Float32, Float32}, got a value o
function typed_callable(callable::Callable, ::Type{Return}, ::Type{Arguments}) where {
Return, Arguments <: Tuple, Callable,
}
ret = return_type_enforcer(Return)
with_argument_types = CallableWithArgumentTypes{Arguments}(callable)
ret ∘ with_argument_types
if callable isa CallableWithTypeSignature{Return, Arguments} # ensure idempotence
callable
else
typed_callable_no_special_casing(callable, Return, Arguments)
end
end

"""
Expand Down Expand Up @@ -151,8 +168,11 @@ ERROR: TypeError: in typeassert, expected Float32, got a value of type Float64
function typed_callable(callable::Callable, ::Type{Return}) where {
Return, Callable,
}
ret = return_type_enforcer(Return)
ret ∘ callable
if callable isa CallableWithReturnType{Return} # ensure idempotence
callable
else
typed_callable_no_special_casing(callable, Return)
end
end

end
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ using Aqua: Aqua
@test 3 === @inferred f(x_int)
@test_throws TypeError f(x_f64)
end
@testset "`typed_callable` idempotency" begin
@test let t = typed_callable(sin, Float32)
t === @inferred typed_callable(t, Float32)
end
end
end

@testset "`CallableWithTypeSignature`" begin
Expand Down Expand Up @@ -69,5 +74,10 @@ using Aqua: Aqua
@test 3 === @inferred f(3.0)
end
end
@testset "`typed_callable` idempotency" begin
@test let t = typed_callable(sin, Float32, Tuple{Float32})
t === @inferred typed_callable(t, Float32, Tuple{Float32})
end
end
end
end