Skip to content

Commit fd9b469

Browse files
authored
fix: broadcasted type casting (#156)
* fix: broadcasted type casting * fix: optimisers compilation * chore: apply formatting * fix: handle case where the results get optimized out * test: add tests for type casting * fix: type restrict conversion to TypeCast
1 parent 51bd9b5 commit fd9b469

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

src/Compiler.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ function create_result(tocopy::T, path, result_stores) where {T}
2929
end
3030

3131
function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N}
32-
restore = result_stores[path]
33-
delete!(result_stores, path)
34-
return :(ConcreteRArray{$T,$N}($restore, $(tocopy.shape)))
32+
if haskey(result_stores, path)
33+
restore = result_stores[path]
34+
delete!(result_stores, path)
35+
return :(ConcreteRArray{$T,$N}($restore, $(tocopy.shape)))
36+
end
37+
# We will set the data for this later
38+
return :(ConcreteRArray{$T,$N}($(tocopy.data), $(tocopy.shape)))
3539
end
3640

3741
function create_result(tocopy::Array{T,N}, path, result_stores) where {T,N}
@@ -67,7 +71,9 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
6771
end
6872

6973
function create_result(
70-
tocopy::Union{Int,AbstractFloat,AbstractString,Nothing,Type,Symbol}, path, result_stores
74+
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
75+
path,
76+
result_stores,
7177
)
7278
return Meta.quot(tocopy)
7379
end

src/TracedRArray.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,18 @@ for (jlop, hloop) in (
338338
end
339339
end
340340

341+
struct TypeCast{T<:Number} <: Function end
342+
343+
function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2}
344+
return promote_to(TracedRArray{T,0}, x)
345+
end
346+
347+
elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
348+
function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number}
349+
# Special Path to prevent going down a despecialized path
350+
return elem_apply(TypeCast{T}(), x)
351+
end
352+
341353
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
342354
all(iszero ndims, args) && return f(args...)
343355

test/compile.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,20 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
2727

2828
@test fworld(a, b) ones(2, 2) * 10
2929
end
30+
31+
@testset "type casting & optimized out returns" begin
32+
a = Reactant.ConcreteRArray(rand(2, 10))
33+
34+
ftype1(x) = Float64.(x)
35+
ftype2(x) = Float32.(x)
36+
37+
ftype1_compiled = @compile ftype1(a)
38+
ftype2_compiled = @compile ftype2(a)
39+
40+
@test ftype1_compiled(a) isa Reactant.ConcreteRArray{Float64,2}
41+
@test ftype2_compiled(a) isa Reactant.ConcreteRArray{Float32,2}
42+
43+
@test ftype1_compiled(a) Float64.(a)
44+
@test ftype2_compiled(a) Float32.(a)
45+
end
3046
end

0 commit comments

Comments
 (0)