Skip to content

Commit 755736b

Browse files
Add C-API for constructing Complex Attributes (#208)
* Add C-API to construct ComplexF32 and ComplexF64 attributes * Link against MLIR Complex dialect * Rename C functions and delete `float` version * Add the Julia bindings to the C routines * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent fce399c commit 755736b

File tree

4 files changed

+51
-0
lines changed

4 files changed

+51
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/Async/IR/Async.h"
18+
#include "mlir/Dialect/Complex/IR/Complex.h"
1819
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1920
#include "mlir/Dialect/DLTI/DLTI.h"
2021
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
@@ -63,6 +64,19 @@ using namespace mlir;
6364
using namespace llvm;
6465
using namespace xla;
6566

67+
// MLIR C-API extras
68+
#pragma region MLIR Extra
69+
MLIR_CAPI_EXPORTED MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx, MlirType type, double real, double imag) {
70+
return wrap(complex::NumberAttr::get(unwrap(type), real, imag));
71+
}
72+
73+
MLIR_CAPI_EXPORTED MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc, MlirType type, double real, double imag) {
74+
return wrap(complex::NumberAttr::getChecked(unwrap(loc), unwrap(type), unwrap(type), real, imag));
75+
}
76+
77+
MlirTypeID mlirComplexAttrGetTypeID(void) { return wrap(complex::NumberAttr::getTypeID()); }
78+
#pragma endregion
79+
6680
// int google::protobuf::io::CodedInputStream::default_recursion_limit_ = 100;
6781
// int xla::_LayoutProto_default_instance_;
6882

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ cc_library(
308308
"@llvm-project//mlir:AllPassesAndDialects",
309309
"@llvm-project//mlir:ArithDialect",
310310
"@llvm-project//mlir:AsyncDialect",
311+
"@llvm-project//mlir:ComplexDialect",
311312
"@llvm-project//mlir:ControlFlowDialect",
312313
"@llvm-project//mlir:ConversionPasses",
313314
"@llvm-project//mlir:DLTIDialect",

src/mlir/IR/Attribute.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,29 @@ function Base.Float64(attr::Attribute)
133133
return API.mlirFloatAttrGetValueDouble(attr)
134134
end
135135

136+
"""
137+
Attribute(complex; context=context(), location=Location(), check=false)
138+
139+
Creates a complex attribute in the given context with the given complex value and double-precision FP semantics.
140+
"""
141+
function Attribute(
142+
c::T; context::Context=context(), location::Location=Location(), check::Bool=false
143+
) where {T<:Complex}
144+
if check
145+
Attribute(
146+
API.mlirComplexAttrDoubleGetChecked(
147+
location, Type(T), Float64(real(c)), Float64(imag(c))
148+
),
149+
)
150+
else
151+
Attribute(
152+
API.mlirComplexAttrDoubleGet(
153+
context, Type(T), Float64(real(c)), Float64(imag(c))
154+
),
155+
)
156+
end
157+
end
158+
136159
"""
137160
isinteger(attr)
138161

src/mlir/MLIR.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@ module API
1111
let
1212
include("libMLIR_h.jl")
1313
end
14+
15+
# MLIR C API - extra
16+
function mlirComplexAttrDoubleGet(ctx, type, real, imag)
17+
@ccall mlir_c.mlirComplexAttrDoubleGet(
18+
ctx::MlirContext, type::MlirType, real::Cdouble, imag::Cdouble
19+
)::MlirAttribute
20+
end
21+
22+
function mlirComplexAttrDoubleGetChecked(loc, type, real, imag)
23+
@ccall mlir_c.mlirComplexAttrDoubleGetChecked(
24+
loc::MlirLocation, type::MlirType, real::Cdouble, imag::Cdouble
25+
)::MlirAttribute
26+
end
1427
end # module API
1528

1629
include("IR/IR.jl")

0 commit comments

Comments
 (0)