Skip to content

Commit 1585483

Browse files
committed
refactored generate op
1 parent 3d66c7a commit 1585483

File tree

3 files changed

+306
-48
lines changed

3 files changed

+306
-48
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -368,26 +368,17 @@ extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) {
368368
return wrap(mlir::enzyme::TraceType::get(unwrap(ctx)));
369369
}
370370

371+
extern "C" MLIR_CAPI_EXPORTED MlirType
372+
enzymeConstraintTypeGet(MlirContext ctx) {
373+
return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx)));
374+
}
375+
371376
extern "C" MLIR_CAPI_EXPORTED MlirAttribute
372377
enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) {
373378
mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol);
374379
return wrap(attr);
375380
}
376381

377-
extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet(
378-
MlirContext ctx, uint64_t symbol, MlirAttribute values) {
379-
mlir::Attribute vals = unwrap(values);
380-
auto arr = llvm::dyn_cast<mlir::ArrayAttr>(vals);
381-
if (!arr) {
382-
ReactantThrowError(
383-
"enzymeConstraintAttrGet: `values` must be an ArrayAttr");
384-
return MlirAttribute{nullptr};
385-
}
386-
mlir::Attribute attr =
387-
mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr);
388-
return wrap(attr);
389-
}
390-
391382
// Create profiler session and start profiling
392383
extern "C" tsl::ProfilerSession *
393384
CreateProfilerSession(uint32_t device_tracer_level,

0 commit comments

Comments
 (0)