Skip to content

Commit 73680d2

Browse files
authored
Raiselib (#1022)
* raise lib * add files * builds * More * fix * fixup * more importing * raising * more * fix * fix * fmt * fix * more lowering * launch * fmt
1 parent 4a19bc9 commit 73680d2

26 files changed

+11387
-99
lines changed

BUILD

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ cc_binary(
4444
"@llvm-project//mlir:GPUToLLVMIRTranslation",
4545
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
4646
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
47+
"@tsl//tsl/platform:env",
48+
"@tsl//tsl/platform:env_impl",
4749
] + if_llvm_aarch32_available([
4850
"@llvm-project//llvm:ARMAsmParser",
4951
"@llvm-project//llvm:ARMCodeGen",
@@ -66,6 +68,53 @@ cc_binary(
6668
],
6769
)
6870

71+
cc_library(
72+
name = "RaiseLib",
73+
srcs = [
74+
"//src/enzyme_ad/jax:raise.cpp",
75+
"//src/enzyme_ad/jax:RegistryUtils.cpp",
76+
],
77+
visibility = ["//visibility:public"],
78+
deps = [
79+
"@llvm-project//mlir:MlirOptLib",
80+
"//src/enzyme_ad/jax:RegistryUtils",
81+
"@llvm-project//mlir:GPUToLLVMIRTranslation",
82+
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
83+
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
84+
"@tsl//tsl/platform:env",
85+
"@tsl//tsl/platform:env_impl",
86+
] + if_llvm_aarch32_available([
87+
"@llvm-project//llvm:ARMAsmParser",
88+
"@llvm-project//llvm:ARMCodeGen",
89+
]) + if_llvm_aarch64_available([
90+
"@llvm-project//llvm:AArch64AsmParser",
91+
"@llvm-project//llvm:AArch64CodeGen",
92+
]) + if_llvm_powerpc_available([
93+
"@llvm-project//llvm:PowerPCAsmParser",
94+
"@llvm-project//llvm:PowerPCCodeGen",
95+
]) + if_llvm_system_z_available([
96+
"@llvm-project//llvm:SystemZAsmParser",
97+
"@llvm-project//llvm:SystemZCodeGen",
98+
]) + if_llvm_x86_available([
99+
"@llvm-project//llvm:X86AsmParser",
100+
"@llvm-project//llvm:X86CodeGen",
101+
]),
102+
alwayslink = True,
103+
linkstatic = True,
104+
copts = [
105+
"-Wno-unused-variable",
106+
"-Wno-return-type",
107+
],
108+
)
109+
110+
# cc_shared_library(
111+
cc_binary(
112+
name = "libRaise.so",
113+
linkshared = 1, ## important
114+
linkstatic = 1, ## important
115+
deps = [":RaiseLib"],
116+
)
117+
69118
cc_binary(
70119
name = "enzymexlamlir-tblgen",
71120
srcs = ["//src/enzyme_ad/tools:enzymexlamlir-tblgen.cpp"],

src/enzyme_ad/jax/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ load("@rules_python//python:py_library.bzl", "py_library")
55

66
exports_files([
77
"enzymexlamlir-opt.cpp",
8+
"raise.cpp",
89
"RegistryUtils.cpp",
910
])
1011

@@ -387,8 +388,10 @@ gentbl_cc_library(
387388
td_file = "Dialect/EnzymeXLAOps.td",
388389
deps = [
389390
":EnzymeXLADialectTdFiles",
391+
"@llvm-project//mlir:CopyOpInterfaceTdFiles",
390392
"@enzyme//:EnzymeDialectTdFiles",
391393
"@stablehlo//:stablehlo_ops_td_files",
394+
"@llvm-project//mlir:GPUOpsTdFiles",
392395
],
393396
)
394397

@@ -443,6 +446,7 @@ cc_library(
443446
srcs = glob(
444447
[
445448
"Implementations/*.cpp",
449+
"Utils.cpp",
446450
"Passes/*.cpp",
447451
"Dialect/*.cpp",
448452
],
@@ -487,6 +491,7 @@ cc_library(
487491
"@llvm-project//llvm:Support",
488492
"@llvm-project//mlir:AffineAnalysis",
489493
"@llvm-project//mlir:AffineDialect",
494+
"@llvm-project//mlir:AsyncDialect",
490495
"@llvm-project//mlir:AffineToStandard",
491496
"@llvm-project//mlir:AffineTransforms",
492497
"@llvm-project//mlir:AffineUtils",
@@ -495,6 +500,7 @@ cc_library(
495500
"@llvm-project//mlir:ArithToLLVM",
496501
"@llvm-project//mlir:ArithUtils",
497502
"@llvm-project//mlir:BytecodeOpInterface",
503+
"@llvm-project//mlir:CopyOpInterface",
498504
"@llvm-project//mlir:CallOpInterfaces",
499505
"@llvm-project//mlir:CommonFolders",
500506
"@llvm-project//mlir:ComplexDialect",
@@ -531,6 +537,7 @@ cc_library(
531537
"@llvm-project//mlir:NVGPUDialect",
532538
"@llvm-project//mlir:NVGPUToNVVM",
533539
"@llvm-project//mlir:NVVMDialect",
540+
"@llvm-project//mlir:ROCDLDialect",
534541
"@llvm-project//mlir:NVVMToLLVM",
535542
"@llvm-project//mlir:OpenMPDialect",
536543
"@llvm-project//mlir:OpenMPToLLVM",

src/enzyme_ad/jax/Dialect/Dialect.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,50 @@
1111
#include "Ops.h"
1212
#include "mlir/IR/DialectImplementation.h"
1313

14+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1415
#include "mlir/IR/Builders.h"
1516
#include "llvm/ADT/TypeSwitch.h"
17+
#include "llvm/Support/InterleavedRange.h"
18+
#include "llvm/Support/LogicalResult.h"
1619

1720
#include "mlir/IR/Dialect.h"
1821
#include "mlir/Transforms/InliningUtils.h"
1922

2023
// #include "Dialect/EnzymeEnums.cpp.inc"
2124
#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.cpp.inc"
2225

26+
static llvm::ParseResult parseAsyncDependencies(
27+
mlir::OpAsmParser &parser, mlir::Type &asyncTokenType,
28+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
29+
&asyncDependencies) {
30+
using namespace mlir;
31+
using namespace mlir::gpu;
32+
auto loc = parser.getCurrentLocation();
33+
if (succeeded(parser.parseOptionalKeyword("async"))) {
34+
if (parser.getNumResults() == 0)
35+
return parser.emitError(loc, "needs to be named when marked 'async'");
36+
asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
37+
}
38+
return parser.parseOperandList(asyncDependencies,
39+
OpAsmParser::Delimiter::OptionalSquare);
40+
}
41+
42+
/// Prints optional async dependencies with its leading keyword.
43+
/// (`async`)? (`[` ssa-id-list `]`)?
44+
// Used by the tablegen assembly format for several async ops.
45+
static void printAsyncDependencies(mlir::OpAsmPrinter &printer,
46+
mlir::Operation *op,
47+
mlir::Type asyncTokenType,
48+
mlir::OperandRange asyncDependencies) {
49+
if (asyncTokenType)
50+
printer << "async";
51+
if (asyncDependencies.empty())
52+
return;
53+
if (asyncTokenType)
54+
printer << ' ';
55+
printer << llvm::interleaved_array(asyncDependencies);
56+
}
57+
2358
#define GET_OP_CLASSES
2459
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc"
2560

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
include "Enzyme/MLIR/Dialect/Dialect.td"
1313
include "Dialect.td"
1414

15+
16+
include "mlir/Interfaces/CopyOpInterface.td"
1517
include "mlir/Interfaces/ViewLikeInterface.td"
1618
include "mlir/IR/SymbolInterfaces.td"
1719
include "mlir/IR/EnumAttr.td"
@@ -26,6 +28,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2628
include "mlir/Interfaces/CallInterfaces.td"
2729
include "mlir/Interfaces/InferTypeOpInterface.td"
2830
include "stablehlo/dialect/Base.td"
31+
include "mlir/Dialect/GPU/IR/GPUBase.td"
2932

3033
def TensorI64 : Type<CPred<"::llvm::isa<::mlir::TensorType>($_self) && ::llvm::cast<::mlir::TensorType>($_self).getShape().size() == 0 && ::llvm::cast<::mlir::TensorType>($_self).getElementType().isSignlessInteger(64)">, "tensor<i64>",
3134
"::mlir::TensorType">,
@@ -62,6 +65,43 @@ def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods<SymbolU
6265
let hasCanonicalizer = 1;
6366
}
6467

68+
def MemcpyOp : EnzymeXLA_Op<"memcpy", [CopyOpInterface]> {
69+
70+
let summary = "GPU memcpy operation";
71+
72+
let description = [{
73+
The `gpu.memcpy` operation copies the content of one memref to another.
74+
75+
The op does not execute before all async dependencies have finished
76+
executing.
77+
78+
If the `async` keyword is present, the op is executed asynchronously (i.e.
79+
it does not block until the execution has finished on the device). In
80+
that case, it returns a !gpu.async.token.
81+
82+
Example:
83+
84+
```mlir
85+
%token = gpu.memcpy async [%dep] %dst, %src : memref<?xf32, 1>, memref<?xf32>
86+
```
87+
}];
88+
89+
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
90+
Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$target,
91+
Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$source,
92+
Index:$size
93+
);
94+
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
95+
96+
let assemblyFormat = [{
97+
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
98+
$target`,` $source `,` $size `:` type($target)`,` type($source) attr-dict
99+
}];
100+
let hasFolder = 1;
101+
let hasVerifier = 1;
102+
let hasCanonicalizer = 1;
103+
}
104+
65105
def JITCallOp: EnzymeXLA_Op<"jit_call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<CallOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
66106
let summary = "JIT Call operation";
67107

@@ -92,6 +132,97 @@ def GetStreamOp : EnzymeXLA_Op<"get_stream", [Pure]> {
92132
let results = (outs AnyType:$result);
93133
}
94134

135+
136+
def GPUWrapperOp : EnzymeXLA_Op<"gpu_wrapper", [
137+
RecursiveMemoryEffects,
138+
AutomaticAllocationScope,
139+
SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]> {
140+
let arguments = (ins Variadic<Index>:$blockDims);
141+
let summary = "Indicates the region contained must be executed on the GPU";
142+
let description = [{
143+
The optional arguments to this operation are suggestions about what block
144+
dimensions this gpu kernel should have - usually taken from kernel launch
145+
params
146+
}];
147+
let results = (outs Index : $result);
148+
let regions = (region SizedRegion<1>:$region);
149+
let skipDefaultBuilders = 1;
150+
let builders = [
151+
OpBuilder<(ins "ValueRange":$blockSizes)>,
152+
OpBuilder<(ins)>];
153+
}
154+
155+
def GPUErrorOp : EnzymeXLA_Op<"gpu_error", [
156+
RecursiveMemoryEffects,
157+
SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
158+
Arguments<(ins)> {
159+
let summary = "Gets the error returned by the gpu operation inside";
160+
// TODO should be i32, not index
161+
let results = (outs Index : $result);
162+
let regions = (region SizedRegion<1>:$region);
163+
let skipDefaultBuilders = 1;
164+
let builders = [OpBuilder<(ins)>];
165+
166+
}
167+
168+
def NoopOp
169+
: EnzymeXLA_Op<"noop",
170+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
171+
let summary = "Noop for preventing folding or transformations";
172+
let arguments = (ins Variadic<Index>:$blockDims);
173+
let skipDefaultBuilders = 1;
174+
let builders = [
175+
OpBuilder<(ins "ValueRange":$indices)>];
176+
let description = [{}];
177+
}
178+
179+
180+
def GPUBlockOp : EnzymeXLA_Op<"gpu_block", [
181+
RecursiveMemoryEffects,
182+
SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
183+
Arguments<(ins Index:$blockIndexX, Index:$blockIndexY, Index:$blockIndexZ)> {
184+
let summary = "Wraps a GPU kernel block to prevent restructuring";
185+
let regions = (region SizedRegion<1>:$region);
186+
let skipDefaultBuilders = 1;
187+
let builders = [OpBuilder<(ins
188+
"Value":$blockIndexX, "Value":$blockIndexY, "Value":$blockIndexZ)>];
189+
}
190+
191+
def GPUThreadOp : EnzymeXLA_Op<"gpu_thread", [
192+
RecursiveMemoryEffects,
193+
SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
194+
Arguments<(ins Index:$threadIndexX, Index:$threadIndexY, Index:$threadIndexZ)> {
195+
let summary = "Wraps a GPU kernel thread to prevent restructuring";
196+
let regions = (region SizedRegion<1>:$region);
197+
let skipDefaultBuilders = 1;
198+
let builders = [OpBuilder<(ins
199+
"Value":$threadIndexX, "Value":$threadIndexY, "Value":$threadIndexZ)>];
200+
}
201+
202+
def BarrierOp
203+
: EnzymeXLA_Op<"barrier",
204+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
205+
206+
let arguments = (ins Variadic<Index>:$indices);
207+
let summary = "barrier for parallel loops";
208+
let description = [{}];
209+
let hasCanonicalizer = true;
210+
}
211+
212+
def PolygeistYieldOp : EnzymeXLA_Op<"polygeist_yield", [Pure, ReturnLike, Terminator]> {
213+
//ParentOneOf<["AlternativesOp", "GPUWrapperOp", "GPUErrorOp", "GPUBlockOp", "GPUThreadOp"]>]> {
214+
let summary = "Polygeist ops terminator";
215+
}
216+
217+
def StreamToTokenOp : EnzymeXLA_Op<"stream2token", [
218+
Pure
219+
]> {
220+
let summary = "Extract an async stream from a cuda stream";
221+
222+
let arguments = (ins AnyType : $source);
223+
let results = (outs AnyType : $result);
224+
}
225+
95226
def Memref2PointerOp : EnzymeXLA_Op<"memref2pointer", [
96227
ViewLikeOpInterface, Pure
97228
]> {

0 commit comments

Comments
 (0)