@@ -34,6 +34,44 @@ function scope(
34
34
)
35
35
end
36
36
37
+ function alternatives (; regions:: Vector{Region} , location= Location ())
38
+ op_ty_results = IR. Type[]
39
+ operands = Value[]
40
+ owned_regions = Region[regions... ,]
41
+ successors = Block[]
42
+ attributes = NamedAttribute[]
43
+
44
+ return create_operation (
45
+ " enzymexla.alternatives" ,
46
+ location;
47
+ operands,
48
+ owned_regions,
49
+ successors,
50
+ attributes,
51
+ results= op_ty_results,
52
+ result_inference= false ,
53
+ )
54
+ end
55
+
56
+ function barrier (indices:: Vector{Value} ; location= Location ())
57
+ op_ty_results = IR. Type[]
58
+ operands = Value[indices... ,]
59
+ owned_regions = Region[]
60
+ successors = Block[]
61
+ attributes = NamedAttribute[]
62
+
63
+ return create_operation (
64
+ " enzymexla.barrier" ,
65
+ location;
66
+ operands,
67
+ owned_regions,
68
+ successors,
69
+ attributes,
70
+ results= op_ty_results,
71
+ result_inference= false ,
72
+ )
73
+ end
74
+
37
75
function comm_region (; result_0:: Vector{IR.Type} , body:: Region , location= Location ())
38
76
op_ty_results = IR. Type[result_0... ,]
39
77
operands = Value[]
@@ -84,6 +122,103 @@ function extend(
84
122
)
85
123
end
86
124
125
+ function gpu_block (
126
+ blockIndexX:: Value ,
127
+ blockIndexY:: Value ,
128
+ blockIndexZ:: Value ;
129
+ region:: Region ,
130
+ location= Location (),
131
+ )
132
+ op_ty_results = IR. Type[]
133
+ operands = Value[blockIndexX, blockIndexY, blockIndexZ]
134
+ owned_regions = Region[region,]
135
+ successors = Block[]
136
+ attributes = NamedAttribute[]
137
+
138
+ return create_operation (
139
+ " enzymexla.gpu_block" ,
140
+ location;
141
+ operands,
142
+ owned_regions,
143
+ successors,
144
+ attributes,
145
+ results= op_ty_results,
146
+ result_inference= false ,
147
+ )
148
+ end
149
+
150
+ function gpu_error (; result:: IR.Type , region:: Region , location= Location ())
151
+ op_ty_results = IR. Type[result,]
152
+ operands = Value[]
153
+ owned_regions = Region[region,]
154
+ successors = Block[]
155
+ attributes = NamedAttribute[]
156
+
157
+ return create_operation (
158
+ " enzymexla.gpu_error" ,
159
+ location;
160
+ operands,
161
+ owned_regions,
162
+ successors,
163
+ attributes,
164
+ results= op_ty_results,
165
+ result_inference= false ,
166
+ )
167
+ end
168
+
169
+ function gpu_thread (
170
+ threadIndexX:: Value ,
171
+ threadIndexY:: Value ,
172
+ threadIndexZ:: Value ;
173
+ region:: Region ,
174
+ location= Location (),
175
+ )
176
+ op_ty_results = IR. Type[]
177
+ operands = Value[threadIndexX, threadIndexY, threadIndexZ]
178
+ owned_regions = Region[region,]
179
+ successors = Block[]
180
+ attributes = NamedAttribute[]
181
+
182
+ return create_operation (
183
+ " enzymexla.gpu_thread" ,
184
+ location;
185
+ operands,
186
+ owned_regions,
187
+ successors,
188
+ attributes,
189
+ results= op_ty_results,
190
+ result_inference= false ,
191
+ )
192
+ end
193
+
194
+ """
195
+ `gpu_wrapper`
196
+
197
+ The optional arguments to this operation are suggestions about what block
198
+ dimensions this gpu kernel should have - usually taken from kernel launch
199
+ params
200
+ """
201
+ function gpu_wrapper (
202
+ blockDims:: Vector{Value} ; result:: IR.Type , region:: Region , location= Location ()
203
+ )
204
+ op_ty_results = IR. Type[result,]
205
+ operands = Value[blockDims... ,]
206
+ owned_regions = Region[region,]
207
+ successors = Block[]
208
+ attributes = NamedAttribute[]
209
+
210
+ return create_operation (
211
+ " enzymexla.gpu_wrapper" ,
212
+ location;
213
+ operands,
214
+ owned_regions,
215
+ successors,
216
+ attributes,
217
+ results= op_ty_results,
218
+ result_inference= false ,
219
+ )
220
+ end
221
+
87
222
function get_stream (; result:: IR.Type , location= Location ())
88
223
op_ty_results = IR. Type[result,]
89
224
operands = Value[]
@@ -214,6 +349,51 @@ function linalg_lu(
214
349
)
215
350
end
216
351
352
+ """
353
+ `memcpy`
354
+
355
+ The `gpu.memcpy` operation copies the content of one memref to another.
356
+
357
+ The op does not execute before all async dependencies have finished
358
+ executing.
359
+
360
+ If the `async` keyword is present, the op is executed asynchronously (i.e.
361
+ it does not block until the execution has finished on the device). In
362
+ that case, it returns a !gpu.async.token.
363
+
364
+ # Example
365
+
366
+ ```mlir
367
+ %token = gpu.memcpy async [%dep] %dst, %src : memref<?xf32, 1>, memref<?xf32>
368
+ ```
369
+ """
370
+ function memcpy (
371
+ asyncDependencies:: Vector{Value} ,
372
+ target:: Value ,
373
+ source:: Value ,
374
+ size:: Value ;
375
+ asyncToken= nothing :: Union{Nothing,IR.Type} ,
376
+ location= Location (),
377
+ )
378
+ op_ty_results = IR. Type[]
379
+ operands = Value[asyncDependencies... , target, source, size]
380
+ owned_regions = Region[]
381
+ successors = Block[]
382
+ attributes = NamedAttribute[]
383
+ ! isnothing (asyncToken) && push! (op_ty_results, asyncToken)
384
+
385
+ return create_operation (
386
+ " enzymexla.memcpy" ,
387
+ location;
388
+ operands,
389
+ owned_regions,
390
+ successors,
391
+ attributes,
392
+ results= op_ty_results,
393
+ result_inference= false ,
394
+ )
395
+ end
396
+
217
397
function memref2pointer (source:: Value ; result:: IR.Type , location= Location ())
218
398
op_ty_results = IR. Type[result,]
219
399
operands = Value[source,]
@@ -233,6 +413,25 @@ function memref2pointer(source::Value; result::IR.Type, location=Location())
233
413
)
234
414
end
235
415
416
+ function noop (blockDims:: Vector{Value} ; location= Location ())
417
+ op_ty_results = IR. Type[]
418
+ operands = Value[blockDims... ,]
419
+ owned_regions = Region[]
420
+ successors = Block[]
421
+ attributes = NamedAttribute[]
422
+
423
+ return create_operation (
424
+ " enzymexla.noop" ,
425
+ location;
426
+ operands,
427
+ owned_regions,
428
+ successors,
429
+ attributes,
430
+ results= op_ty_results,
431
+ result_inference= false ,
432
+ )
433
+ end
434
+
236
435
function pointer2memref (source:: Value ; result:: IR.Type , location= Location ())
237
436
op_ty_results = IR. Type[result,]
238
437
operands = Value[source,]
@@ -252,6 +451,46 @@ function pointer2memref(source::Value; result::IR.Type, location=Location())
252
451
)
253
452
end
254
453
454
+ function polygeist_yield (; location= Location ())
455
+ op_ty_results = IR. Type[]
456
+ operands = Value[]
457
+ owned_regions = Region[]
458
+ successors = Block[]
459
+ attributes = NamedAttribute[]
460
+
461
+ return create_operation (
462
+ " enzymexla.polygeist_yield" ,
463
+ location;
464
+ operands,
465
+ owned_regions,
466
+ successors,
467
+ attributes,
468
+ results= op_ty_results,
469
+ result_inference= false ,
470
+ )
471
+ end
472
+
473
+ function linalg_qr (
474
+ input:: Value ; output:: IR.Type , tau:: IR.Type , info:: IR.Type , location= Location ()
475
+ )
476
+ op_ty_results = IR. Type[output, tau, info]
477
+ operands = Value[input,]
478
+ owned_regions = Region[]
479
+ successors = Block[]
480
+ attributes = NamedAttribute[]
481
+
482
+ return create_operation (
483
+ " enzymexla.linalg.qr" ,
484
+ location;
485
+ operands,
486
+ owned_regions,
487
+ successors,
488
+ attributes,
489
+ results= op_ty_results,
490
+ result_inference= false ,
491
+ )
492
+ end
493
+
255
494
function rotate (
256
495
operand:: Value ;
257
496
result= nothing :: Union{Nothing,IR.Type} ,
@@ -280,6 +519,25 @@ function rotate(
280
519
)
281
520
end
282
521
522
+ function stream2token (source:: Value ; result:: IR.Type , location= Location ())
523
+ op_ty_results = IR. Type[result,]
524
+ operands = Value[source,]
525
+ owned_regions = Region[]
526
+ successors = Block[]
527
+ attributes = NamedAttribute[]
528
+
529
+ return create_operation (
530
+ " enzymexla.stream2token" ,
531
+ location;
532
+ operands,
533
+ owned_regions,
534
+ successors,
535
+ attributes,
536
+ results= op_ty_results,
537
+ result_inference= false ,
538
+ )
539
+ end
540
+
283
541
function wrap (
284
542
operand:: Value ;
285
543
result= nothing :: Union{Nothing,IR.Type} ,
@@ -311,4 +569,23 @@ function wrap(
311
569
)
312
570
end
313
571
572
+ function xla_wrapper (inputs:: Vector{Value} ; fn, location= Location ())
573
+ op_ty_results = IR. Type[]
574
+ operands = Value[inputs... ,]
575
+ owned_regions = Region[]
576
+ successors = Block[]
577
+ attributes = NamedAttribute[namedattribute (" fn" , fn),]
578
+
579
+ return create_operation (
580
+ " enzymexla.xla_wrapper" ,
581
+ location;
582
+ operands,
583
+ owned_regions,
584
+ successors,
585
+ attributes,
586
+ results= op_ty_results,
587
+ result_inference= false ,
588
+ )
589
+ end
590
+
314
591
end # enzymexla
0 commit comments