Skip to content

Commit 35a51ad

Browse files
authored
fix primitive_type for complex (#193)
* fix primitive_type for complex * add simple complex runtime test
1 parent c507e37 commit 35a51ad

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/XLA.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ end
233233

234234
@inline primitive_type(::Type{Float64}) = 12
235235

236-
@inline primitive_type(::Type{Complex{Float32}}) = 24
237-
@inline primitive_type(::Type{Complex{Float64}}) = 25
236+
@inline primitive_type(::Type{Complex{Float32}}) = 15
237+
@inline primitive_type(::Type{Complex{Float64}}) = 18
238238

239239
function ArrayFromHostBuffer(client::Client, array::Array{T,N}, device) where {T,N}
240240
sizear = Int64[s for s in reverse(size(array))]

test/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,10 @@ end
438438
@test size(f(y)) == size(x)
439439
@test eltype(f(y)) == eltype(x)
440440
end
441+
442+
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
443+
a = Reactant.to_rarray(ones(CT, 2))
444+
b = Reactant.to_rarray(ones(CT, 2))
445+
c = Reactant.compile(+, (a, b))(a, b)
446+
@test c == ones(CT, 2) + ones(CT, 2)
447+
end

0 commit comments

Comments
 (0)