1212include "Enzyme/MLIR/Dialect/Dialect.td"
1313include "Dialect.td"
1414
15+
16+ include "mlir/Interfaces/CopyOpInterface.td"
1517include "mlir/Interfaces/ViewLikeInterface.td"
1618include "mlir/IR/SymbolInterfaces.td"
1719include "mlir/IR/EnumAttr.td"
@@ -26,6 +28,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2628include "mlir/Interfaces/CallInterfaces.td"
2729include "mlir/Interfaces/InferTypeOpInterface.td"
2830include "stablehlo/dialect/Base.td"
31+ include "mlir/Dialect/GPU/IR/GPUBase.td"
2932
3033def TensorI64 : Type<CPred<"::llvm::isa<::mlir::TensorType>($_self) && ::llvm::cast<::mlir::TensorType>($_self).getShape().size() == 0 && ::llvm::cast<::mlir::TensorType>($_self).getElementType().isSignlessInteger(64)">, "tensor<i64>",
3134 "::mlir::TensorType">,
@@ -62,6 +65,43 @@ def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods<SymbolU
6265 let hasCanonicalizer = 1;
6366}
6467
// Copy between two memrefs with optional async-token chaining. The operand
// list and assembly format mirror the description below, which appears to be
// adapted from the upstream `gpu.memcpy` op with an additional `$size`
// operand.
// NOTE(review): the `enzymexla.` mnemonic prefix is inferred from the
// `enzymexla::` C++ namespace used elsewhere in this file, and the unit of
// `$size` (bytes vs. elements) is not visible here - confirm both against the
// dialect definition and the verifier/lowering.
def MemcpyOp : EnzymeXLA_Op<"memcpy", [CopyOpInterface]> {

  let summary = "GPU memcpy operation";

  let description = [{
    The `enzymexla.memcpy` operation copies the content of one memref to
    another. The `size` operand is an `index` value giving the amount of data
    to copy.

    The op does not execute before all async dependencies have finished
    executing.

    If the `async` keyword is present, the op is executed asynchronously (i.e.
    it does not block until the execution has finished on the device). In
    that case, it returns a !gpu.async.token.

    Example:

    ```mlir
    %token = enzymexla.memcpy async [%dep] %dst, %src, %size : memref<?xf32, 1>, memref<?xf32>
    ```
  }];

  // `$target` is written and `$source` is read; both effects are declared at
  // stage 0 as full effects. `$size` is a plain index operand.
  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                   Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$target,
                   Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$source,
                   Index:$size
  );
  // The token result is only present for the `async` form.
  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);

  let assemblyFormat = [{
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $target`,` $source `,` $size `:` type($target)`,` type($source) attr-dict
  }];
  // fold(), verify() and getCanonicalizationPatterns() are implemented in C++.
  let hasFolder = 1;
  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}

104+
65105def JITCallOp: EnzymeXLA_Op<"jit_call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<CallOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
66106 let summary = "JIT Call operation";
67107
@@ -92,6 +132,97 @@ def GetStreamOp : EnzymeXLA_Op<"get_stream", [Pure]> {
92132 let results = (outs AnyType:$result);
93133}
94134
135+
// Single-block region op whose body is marked for GPU execution. It takes
// zero or more index operands as block-dimension hints and produces one index
// result.
def GPUWrapperOp : EnzymeXLA_Op<"gpu_wrapper", [
    RecursiveMemoryEffects,
    AutomaticAllocationScope,
    SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]> {
  let summary = "Indicates the region contained must be executed on the GPU";
  let description = [{
    The optional arguments to this operation are suggestions about what block
    dimensions this gpu kernel should have - usually taken from kernel launch
    params
  }];

  // Operands, result and the single one-block region.
  let arguments = (ins Variadic<Index>:$blockDims);
  let results = (outs Index:$result);
  let regions = (region SizedRegion<1>:$region);

  // Only the two custom builders declared here are generated as declarations;
  // the default ODS builders are suppressed.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "ValueRange":$blockSizes)>,
    OpBuilder<(ins)>
  ];
}

154+
// Operand-less single-block region op that yields one index result, described
// by its summary as the error of the GPU operation nested inside.
def GPUErrorOp
    : EnzymeXLA_Op<"gpu_error",
                   [RecursiveMemoryEffects,
                    SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
      Arguments<(ins)> {
  let summary = "Gets the error returned by the gpu operation inside";

  let regions = (region SizedRegion<1>:$region);
  // TODO should be i32, not index
  let results = (outs Index:$result);

  // Default builders are suppressed; only the argument-less builder is
  // declared.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins)>
  ];
}

167+
// Region-less op taking a variadic list of index operands. Its memory effects
// are declared here and implemented in C++ via MemoryEffectsOpInterface.
def NoopOp : EnzymeXLA_Op<"noop", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "Noop for preventing folding or transformations";
  let description = [{}];

  let arguments = (ins Variadic<Index>:$blockDims);

  // Default builders are suppressed; the single custom builder takes the
  // operand values as a ValueRange.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "ValueRange":$indices)>
  ];
}

178+
179+
// Single-block region op carrying three index operands: the X, Y and Z block
// indices.
def GPUBlockOp
    : EnzymeXLA_Op<"gpu_block",
                   [RecursiveMemoryEffects,
                    SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
      Arguments<(ins Index:$blockIndexX, Index:$blockIndexY, Index:$blockIndexZ)> {
  let summary = "Wraps a GPU kernel block to prevent restructuring";

  let regions = (region SizedRegion<1>:$region);

  // Default builders are suppressed; the custom builder takes one Value per
  // block index.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$blockIndexX,
                   "Value":$blockIndexY,
                   "Value":$blockIndexZ)>
  ];
}

190+
// Single-block region op carrying three index operands: the X, Y and Z thread
// indices.
def GPUThreadOp
    : EnzymeXLA_Op<"gpu_thread",
                   [RecursiveMemoryEffects,
                    SingleBlockImplicitTerminator<"enzymexla::PolygeistYieldOp">]>,
      Arguments<(ins Index:$threadIndexX, Index:$threadIndexY, Index:$threadIndexZ)> {
  let summary = "Wraps a GPU kernel thread to prevent restructuring";

  let regions = (region SizedRegion<1>:$region);

  // Default builders are suppressed; the custom builder takes one Value per
  // thread index.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$threadIndexX,
                   "Value":$threadIndexY,
                   "Value":$threadIndexZ)>
  ];
}

201+
// Barrier op over a variadic list of index operands. Its memory effects are
// declared here and implemented in C++, and canonicalization patterns are
// enabled.
def BarrierOp : EnzymeXLA_Op<"barrier", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "barrier for parallel loops";
  let description = [{}];

  let arguments = (ins Variadic<Index>:$indices);

  let hasCanonicalizer = 1;
}

211+
// Terminator used as the implicit terminator of the region ops in this file
// (gpu_wrapper, gpu_error, gpu_block, gpu_thread).
def PolygeistYieldOp
    : EnzymeXLA_Op<"polygeist_yield", [Pure, ReturnLike, Terminator]> {
  // A parent restriction is currently commented out:
  //   ParentOneOf<["AlternativesOp", "GPUWrapperOp", "GPUErrorOp",
  //                "GPUBlockOp", "GPUThreadOp"]>
  let summary = "Polygeist ops terminator";
}

216+
// Pure conversion op with one operand and one result, both unconstrained
// (AnyType) at the ODS level.
def StreamToTokenOp : EnzymeXLA_Op<"stream2token", [Pure]> {
  let summary = "Extract an async stream from a cuda stream";

  let arguments = (ins AnyType:$source);
  let results = (outs AnyType:$result);
}

225+
95226def Memref2PointerOp : EnzymeXLA_Op<"memref2pointer", [
96227 ViewLikeOpInterface, Pure
97228]> {
0 commit comments