Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,64 @@ gentbl_cc_library(
],
)

td_library(
name = "PerfifyDialectFiles",
srcs = [
"Dialect/Perfify/Dialect.td",
"Dialect/Perfify/Ops.td",
],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "PerfifyDialectIncGen",
tbl_outs = [
(
[
"-gen-dialect-decls",
"-dialect=perfify",
],
"Dialect/Perfify/PerfifyDialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=perfify",
],
"Dialect/Perfify/PerfifyDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Dialect/Perfify/Dialect.td",
deps = [
":PerfifyDialectFiles"
],
)

gentbl_cc_library(
name = "PerfifyOpsIncGen",
tbl_outs = [
(
["-gen-op-decls"],
"Dialect/Perfify/PerfifyOps.h.inc",
),
(
["-gen-op-defs"],
"Dialect/Perfify/PerfifyOps.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Dialect/Perfify/Ops.td",
deps = [
":PerfifyDialectFiles"
],
)

cc_library(
name = "CheckedRewrite",
hdrs = ["CheckedRewrite.h"],
Expand Down Expand Up @@ -721,6 +779,7 @@ cc_library(
"Dialect/*.cpp",
"Dialect/Distributed/*.cpp",
"Dialect/Tessera/*.cpp",
"Dialect/Perfify/*.cpp",
]) + [
"Utils.cpp",
],
Expand All @@ -730,6 +789,7 @@ cc_library(
"Dialect/*.h",
"Dialect/Distributed/*.h",
"Dialect/Tessera/*.h",
"Dialect/Perfify/*.h",
]) + [
"Utils.h",
],
Expand Down Expand Up @@ -758,6 +818,8 @@ cc_library(
":StablehloOptPatternsIncGen",
":TesseraDialectIncGen",
":TesseraOpsIncGen",
":PerfifyDialectIncGen",
":PerfifyOpsIncGen",
":chlo-derivatives",
":enzymexla-derivatives",
":mhlo-derivatives",
Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "Dialect.h"

#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.cpp.inc"

// Initialize the dialect
void mlir::enzyme::perfify::PerfifyDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc"
>();
}
19 changes: 19 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H
#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Types.h"

// Include the dialect
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.h.inc"

// Operations
#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.h.inc"

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H
32 changes: 32 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD
#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/Traits.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Perfify dialect definition.
//===----------------------------------------------------------------------===//

def PerfifyDialect : Dialect {
let name = "perfify";
let summary = "A dialect for specifying and proving runtime bounds";
let description = [{
Lets users specify a bound on the number of steps/latency (per a predefined cost model) that a function or other operation should take.
Leverages SAT solvers to automatically prove this, or interactive theorem provers to allow for complete proofs.
}];
let cppNamespace = "::mlir::enzyme::perfify";
}

//===----------------------------------------------------------------------===//
// Base Perfify operation definition.
//===----------------------------------------------------------------------===//

class PerfifyOp<string mnemonic, list<Trait> traits = []>
: Op<PerfifyDialect, mnemonic, traits>;

class PerfifyType<string name> : TypeDef<PerfifyDialect, name>; // may need to be modified

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD
15 changes: 15 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

#include "Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
using namespace mlir::enzyme::perfify;

namespace mlir::perfify {} // namespace mlir::perfify

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc"
50 changes: 50 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Ops.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD
#define ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD

include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "Dialect.td"

// Perfify.cost op
def CostOp : PerfifyOp<"cost", []> {
// summary
// description
// arguments
let arguments = (ins StrAttr:$target_op,
APIntAttr:$cycle_cost);
let assemblyFormat = "$target_op $cycle_cost attr-dict";

}

def ArgOp : PerfifyOp<"arg", []> {
let arguments = (ins I64Attr:$val);
let assemblyFormat = "$val attr-dict";
let results = (outs I64);
}

def AssumeOp : PerfifyOp<"assume", [HasParent<"ConditionsOp">, Terminator]> {
let arguments = (ins I1:$precondition);
let assemblyFormat = "$precondition attr-dict";
}

def ConditionsOp : PerfifyOp<"conditions", [HasParent<"AssumptionsOp">, Terminator]> {
let arguments = (ins FlatSymbolRefAttr:$func_handle,
BoolAttr:$verify_huh);
let regions = (region AnyRegion:$precondition, AnyRegion:$postcondition);

let assemblyFormat = [{
$func_handle $verify_huh attr-dict `pre`
$precondition
`post`
$postcondition
}];
}

def AssumptionsOp : PerfifyOp<"assumptions", [Terminator]> {
let regions = (region AnyRegion:$body);
let assemblyFormat = [{$body attr-dict}];
}

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/RegistryUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Perfify/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"

#include "shardy/dialect/sdy/ir/dialect.h"
Expand Down Expand Up @@ -208,6 +209,7 @@ void registerDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::enzymexla::EnzymeXLADialect>();
registry.insert<mlir::enzyme::distributed::DistributedDialect>();
registry.insert<mlir::enzyme::tessera::TesseraDialect>();
registry.insert<mlir::enzyme::perfify::PerfifyDialect>();
registry.insert<mlir::sdy::SdyDialect>();
registry.insert<mlir::ub::UBDialect>();
registry.insert<mlir::triton::TritonDialect>();
Expand Down
26 changes: 26 additions & 0 deletions test/lit_tests/perfify/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module {
func.func @foo() {func.return}
perfify.assumptions { // operation in the dialect
perfify.cost "arith.mul" 3 // op
perfify.cost "func.return" 0
perfify.cost "scf.yield" 0


perfify.conditions @foo true pre {
%b0 = perfify.arg 0 // op
%c0 = arith.constant 0
%cmp = arith.cmpi eq, %c0, %b0 : i64
perfify.assume %cmp
} post {
// %cost = perfify.fn_cost : perfify.cost
// %c9 = perfify.constant_cost 9 : perfify.cost // then our cost is 9
// %cmp = arith.cmpi eq, %cost, %c9
%b0 = perfify.arg 0 // op
%c0 = arith.constant 0
%cmp = arith.cmpi eq, %c0, %b0 : i64

perfify.assume %cmp
}

}
}