Skip to content

Commit 1dd24b8

Browse files
mofeingPangoraw
andauthored
Implement conj, conj! for TracedRArray (#169)
* Implement `conj` for `TracedRArray` * Generalize `adjoint` for `TracedRVecOrMat` * Implement `conj` for `TracedRNumber` * Implement `conj!` for `TracedRArray` * Implement `DenseElementsAttribute` for array of `Complex` * Implement `Base.real`, `Base.imag` for `TracedRNumber` * Fix typo * Implement `Base.real`, `Base.imag` for `TracedRArray` * Fix pointer length in `MLIR.IR.DenseElementsAttribute` on `Complex{T}` * Move complex tests to new file * Fix `ConcreteRArray` constructor on `Number` * Fix `to_rarray` on primitive number types * Fix `conj` tests on numbers * Write tests for `real`, `imag` * Remove duplicated method * Update src/TracedRNumber.jl Co-authored-by: Paul Berg <[email protected]> * fix `image` on `TracedRArray` of reals --------- Co-authored-by: Paul Berg <[email protected]>
1 parent 2f6a6d9 commit 1dd24b8

File tree

8 files changed

+179
-11
lines changed

8 files changed

+179
-11
lines changed

src/ConcreteRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N}
88
shape::NTuple{N,Int}
99
end
1010

11-
ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray{T,0}(data, ())
11+
ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray(fill(data))
1212

1313
Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)
1414

src/TracedRArray.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,62 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
188188
)
189189
end
190190

191+
Base.conj(A::TracedRArray) = A
192+
function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N}
193+
return TracedRArray{T,N}(
194+
(),
195+
MLIR.IR.result(
196+
MLIR.Dialects.chlo.conj(
197+
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
198+
),
199+
1,
200+
),
201+
size(A),
202+
)
203+
end
204+
205+
Base.conj!(A::TracedRArray) = A
206+
function Base.conj!(A::TracedRArray{T,N}) where {T<:Complex,N}
207+
A.mlir_data = MLIR.IR.result(
208+
MLIR.Dialects.chlo.conj(A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))),
209+
1,
210+
)
211+
return A
212+
end
213+
214+
Base.real(A::TracedRArray) = A
215+
function Base.real(A::TracedRArray{Complex{T},N}) where {T,N}
216+
return TracedRArray{T,N}(
217+
(),
218+
MLIR.IR.result(
219+
MLIR.Dialects.stablehlo.real(
220+
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
221+
),
222+
1,
223+
),
224+
size(A),
225+
)
226+
end
227+
228+
Base.imag(A::TracedRArray) = zero(A)
229+
function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N}
230+
return TracedRArray{T,N}(
231+
(),
232+
MLIR.IR.result(
233+
MLIR.Dialects.stablehlo.imag(
234+
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
235+
),
236+
1,
237+
),
238+
size(A),
239+
)
240+
end
241+
191242
function Base.transpose(A::AnyTracedRVecOrMat)
192243
A = ndims(A) == 1 ? reshape(A, :, 1) : A
193244
return permutedims(A, (2, 1))
194245
end
195-
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
246+
Base.adjoint(A::AnyTracedRVecOrMat) = conj(transpose(A))
196247

197248
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
198249
if isa(rhs, TracedRArray)

src/TracedRNumber.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,36 @@ for (jlop, hloop) in (
199199
end
200200
end
201201

202+
Base.conj(x::TracedRNumber) = x
203+
function Base.conj(x::TracedRNumber{T}) where {T<:Complex}
204+
return TracedRNumber{T}(
205+
(),
206+
MLIR.IR.result(
207+
MLIR.Dialects.chlo.conj(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
208+
),
209+
)
210+
end
211+
212+
Base.real(x::TracedRNumber) = x
213+
function Base.real(x::TracedRNumber{Complex{T}}) where {T}
214+
return TracedRNumber{T}(
215+
(),
216+
MLIR.IR.result(
217+
MLIR.Dialects.stablehlo.real(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
218+
),
219+
)
220+
end
221+
222+
Base.imag(x::TracedRNumber) = zero(x)
223+
function Base.imag(x::TracedRNumber{Complex{T}}) where {T}
224+
return TracedRNumber{T}(
225+
(),
226+
MLIR.IR.result(
227+
MLIR.Dialects.stablehlo.imag(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
228+
),
229+
)
230+
end
231+
202232
# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
203233
Base.abs2(x::TracedRNumber{<:Real}) = x^2
204234

src/Tracing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,5 @@ end
492492
@inline function to_rarray(@nospecialize(x))
493493
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete)
494494
end
495+
496+
to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)

src/mlir/IR/Attribute.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,10 @@ end
567567
function DenseElementsAttribute(values::AbstractArray{<:Complex})
568568
shaped_type = TensorType(size(values), Type(eltype(values)))
569569
# TODO: row major
570-
Attribute(
571-
API.mlirDenseElementsAttrRawBufferGet(shaped_type, length(values) * sizeof(eltype(values)), values)
570+
return Attribute(
571+
API.mlirDenseElementsAttrRawBufferGet(
572+
shaped_type, length(values) * Base.elsize(values), values
573+
),
572574
)
573575
end
574576

src/mlir/IR/Type.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,6 @@ Creates a signless integer type of the given bitwidth in the context. The type i
8585
Type(T::Core.Type{<:Integer}; context::Context=context()) =
8686
Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8))
8787

88-
"""
89-
Type(T::Core.Type{<:Complex}; context=context())
90-
91-
Creates a complex type with the given element type.
92-
"""
93-
Type(T::Core.Type{<:Complex}; context=context()) = Type(API.mlirComplexTypeGet(Type(T(im) |> real |> typeof)))
94-
9588
"""
9689
Type(T::Core.Type{<:Signed}; context=context()
9790

test/complex.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
using Test
2+
using Reactant
3+
4+
@testset "conj" begin
5+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
6+
x_concrete = Reactant.to_rarray(x)
7+
f = @compile conj(x_concrete)
8+
@test only(f(x_concrete)) == conj(x)
9+
end
10+
11+
@testset "$(typeof(x))" for x in (
12+
fill(1.0 + 2.0im),
13+
fill(1.0),
14+
[1.0 + 2.0im; 3.0 + 4.0im],
15+
[1.0; 3.0],
16+
[1.0 + 2.0im 3.0 + 4.0im],
17+
[1.0 2.0],
18+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
19+
[1.0 3.0; 5.0 7.0],
20+
)
21+
x_concrete = Reactant.to_rarray(x)
22+
f = @compile conj(x_concrete)
23+
@test f(x_concrete) == conj(x)
24+
end
25+
end
26+
27+
@testset "conj!" begin
28+
@testset "$(typeof(x))" for x in (
29+
fill(1.0 + 2.0im),
30+
fill(1.0),
31+
[1.0 + 2.0im; 3.0 + 4.0im],
32+
[1.0; 3.0],
33+
[1.0 + 2.0im 3.0 + 4.0im],
34+
[1.0 2.0],
35+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
36+
[1.0 3.0; 5.0 7.0],
37+
)
38+
x_concrete = Reactant.to_rarray(x)
39+
f = @compile conj!(x_concrete)
40+
@test f(x_concrete) == conj(x)
41+
@test x_concrete == conj(x)
42+
end
43+
end
44+
45+
@testset "real" begin
46+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
47+
x_concrete = Reactant.to_rarray(x)
48+
f = @compile real(x_concrete)
49+
@test only(f(x_concrete)) == real(x)
50+
end
51+
52+
@testset "$(typeof(x))" for x in (
53+
fill(1.0 + 2.0im),
54+
fill(1.0),
55+
[1.0 + 2.0im; 3.0 + 4.0im],
56+
[1.0; 3.0],
57+
[1.0 + 2.0im 3.0 + 4.0im],
58+
[1.0 2.0],
59+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
60+
[1.0 3.0; 5.0 7.0],
61+
)
62+
x_concrete = Reactant.to_rarray(x)
63+
f = @compile real(x_concrete)
64+
@test f(x_concrete) == real(x)
65+
end
66+
end
67+
68+
@testset "imag" begin
69+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
70+
x_concrete = Reactant.to_rarray(x)
71+
f = @compile imag(x_concrete)
72+
@test only(f(x_concrete)) == imag(x)
73+
end
74+
75+
@testset "$(typeof(x))" for x in (
76+
fill(1.0 + 2.0im),
77+
fill(1.0),
78+
[1.0 + 2.0im; 3.0 + 4.0im],
79+
[1.0; 3.0],
80+
[1.0 + 2.0im 3.0 + 4.0im],
81+
[1.0 2.0],
82+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
83+
[1.0 3.0; 5.0 7.0],
84+
)
85+
x_concrete = Reactant.to_rarray(x)
86+
f = @compile imag(x_concrete)
87+
@test f(x_concrete) == imag(x)
88+
end
89+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
4646
@safetestset "Layout" include("layout.jl")
4747
@safetestset "Tracing" include("tracing.jl")
4848
@safetestset "Basic" include("basic.jl")
49+
@safetestset "Complex" include("complex.jl")
4950
@safetestset "Broadcast" include("bcast.jl")
5051
@safetestset "Struct" include("struct.jl")
5152
@safetestset "Closure" include("closure.jl")

0 commit comments

Comments
 (0)