@@ -244,6 +244,107 @@ function ml_gelu(
244
244
)
245
245
end
246
246
247
+ """
248
+ `lapack_gemqrt`
249
+
250
+ This operation is modeled after LAPACK\' s *GEMQR routines.
251
+ """
252
+ function lapack_gemqrt (
253
+ V:: Value ,
254
+ T:: Value ,
255
+ C:: Value ;
256
+ output:: IR.Type ,
257
+ side,
258
+ transpose= nothing ,
259
+ location= Location (),
260
+ )
261
+ op_ty_results = IR. Type[output,]
262
+ operands = Value[V, T, C]
263
+ owned_regions = Region[]
264
+ successors = Block[]
265
+ attributes = NamedAttribute[namedattribute (" side" , side),]
266
+ ! isnothing (transpose) && push! (attributes, namedattribute (" transpose" , transpose))
267
+
268
+ return create_operation (
269
+ " enzymexla.lapack.gemqrt" ,
270
+ location;
271
+ operands,
272
+ owned_regions,
273
+ successors,
274
+ attributes,
275
+ results= op_ty_results,
276
+ result_inference= false ,
277
+ )
278
+ end
279
+
280
+ """
281
+ `lapack_geqrf`
282
+
283
+ This operation computes the QR factorization of a matrix using Householder
284
+ reflections. Mathematically, it decomposes A into the product of an
285
+ orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
286
+
287
+ This operation is modeled after LAPACK\' s *GEQRF routines, which returns the
288
+ result in the QR packed format.
289
+ """
290
+ function lapack_geqrf (
291
+ input:: Value ; output:: IR.Type , tau:: IR.Type , info:: IR.Type , location= Location ()
292
+ )
293
+ op_ty_results = IR. Type[output, tau, info]
294
+ operands = Value[input,]
295
+ owned_regions = Region[]
296
+ successors = Block[]
297
+ attributes = NamedAttribute[]
298
+
299
+ return create_operation (
300
+ " enzymexla.lapack.geqrf" ,
301
+ location;
302
+ operands,
303
+ owned_regions,
304
+ successors,
305
+ attributes,
306
+ results= op_ty_results,
307
+ result_inference= false ,
308
+ )
309
+ end
310
+
311
+ """
312
+ `lapack_geqrt`
313
+
314
+ This operation computes the QR factorization of a matrix using Householder
315
+ reflections. Mathematically, it decomposes A into the product of an
316
+ orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
317
+
318
+ This operation is modeled after LAPACK\' s *GEQRT routines, which returns the
319
+ result in the QR CompactWY format.
320
+ """
321
+ function lapack_geqrt (
322
+ input:: Value ;
323
+ output:: IR.Type ,
324
+ T:: IR.Type ,
325
+ info:: IR.Type ,
326
+ blocksize= nothing ,
327
+ location= Location (),
328
+ )
329
+ op_ty_results = IR. Type[output, T, info]
330
+ operands = Value[input,]
331
+ owned_regions = Region[]
332
+ successors = Block[]
333
+ attributes = NamedAttribute[]
334
+ ! isnothing (blocksize) && push! (attributes, namedattribute (" blocksize" , blocksize))
335
+
336
+ return create_operation (
337
+ " enzymexla.lapack.geqrt" ,
338
+ location;
339
+ operands,
340
+ owned_regions,
341
+ successors,
342
+ attributes,
343
+ results= op_ty_results,
344
+ result_inference= false ,
345
+ )
346
+ end
347
+
247
348
function get_stream (; result:: IR.Type , location= Location ())
248
349
op_ty_results = IR. Type[result,]
249
350
operands = Value[]
@@ -270,6 +371,8 @@ function jit_call(
270
371
backend_config= nothing ,
271
372
operand_layouts= nothing ,
272
373
result_layouts= nothing ,
374
+ arg_attrs= nothing ,
375
+ res_attrs= nothing ,
273
376
output_operand_aliases= nothing ,
274
377
xla_side_effect_free= nothing ,
275
378
location= Location (),
@@ -285,6 +388,8 @@ function jit_call(
285
388
push! (attributes, namedattribute (" operand_layouts" , operand_layouts))
286
389
! isnothing (result_layouts) &&
287
390
push! (attributes, namedattribute (" result_layouts" , result_layouts))
391
+ ! isnothing (arg_attrs) && push! (attributes, namedattribute (" arg_attrs" , arg_attrs))
392
+ ! isnothing (res_attrs) && push! (attributes, namedattribute (" res_attrs" , res_attrs))
288
393
! isnothing (output_operand_aliases) &&
289
394
push! (attributes, namedattribute (" output_operand_aliases" , output_operand_aliases))
290
395
! isnothing (xla_side_effect_free) &&
@@ -316,6 +421,8 @@ function kernel_call(
316
421
backend_config= nothing ,
317
422
operand_layouts= nothing ,
318
423
result_layouts= nothing ,
424
+ arg_attrs= nothing ,
425
+ res_attrs= nothing ,
319
426
output_operand_aliases= nothing ,
320
427
xla_side_effect_free= nothing ,
321
428
location= Location (),
@@ -331,6 +438,8 @@ function kernel_call(
331
438
push! (attributes, namedattribute (" operand_layouts" , operand_layouts))
332
439
! isnothing (result_layouts) &&
333
440
push! (attributes, namedattribute (" result_layouts" , result_layouts))
441
+ ! isnothing (arg_attrs) && push! (attributes, namedattribute (" arg_attrs" , arg_attrs))
442
+ ! isnothing (res_attrs) && push! (attributes, namedattribute (" res_attrs" , res_attrs))
334
443
! isnothing (output_operand_aliases) &&
335
444
push! (attributes, namedattribute (" output_operand_aliases" , output_operand_aliases))
336
445
! isnothing (xla_side_effect_free) &&
@@ -457,6 +566,63 @@ function noop(blockDims::Vector{Value}; location=Location())
457
566
)
458
567
end
459
568
569
+ """
570
+ `lapack_orgqr`
571
+
572
+ This operation is modeled after LAPACK\' s *ORGQR/*UNGQR routines.
573
+ """
574
+ function lapack_orgqr (input:: Value , tau:: Value ; output:: IR.Type , location= Location ())
575
+ op_ty_results = IR. Type[output,]
576
+ operands = Value[input, tau]
577
+ owned_regions = Region[]
578
+ successors = Block[]
579
+ attributes = NamedAttribute[]
580
+
581
+ return create_operation (
582
+ " enzymexla.lapack.orgqr" ,
583
+ location;
584
+ operands,
585
+ owned_regions,
586
+ successors,
587
+ attributes,
588
+ results= op_ty_results,
589
+ result_inference= false ,
590
+ )
591
+ end
592
+
593
+ """
594
+ `lapack_ormqr`
595
+
596
+ This operation is modeled after LAPACK\' s *ORMQR routines.
597
+ """
598
+ function lapack_ormqr (
599
+ A:: Value ,
600
+ tau:: Value ,
601
+ C:: Value ;
602
+ output:: IR.Type ,
603
+ side,
604
+ transpose= nothing ,
605
+ location= Location (),
606
+ )
607
+ op_ty_results = IR. Type[output,]
608
+ operands = Value[A, tau, C]
609
+ owned_regions = Region[]
610
+ successors = Block[]
611
+ attributes = NamedAttribute[namedattribute (" side" , side),]
612
+ ! isnothing (transpose) && push! (attributes, namedattribute (" transpose" , transpose))
613
+
614
+ return create_operation (
615
+ " enzymexla.lapack.ormqr" ,
616
+ location;
617
+ operands,
618
+ owned_regions,
619
+ successors,
620
+ attributes,
621
+ results= op_ty_results,
622
+ result_inference= false ,
623
+ )
624
+ end
625
+
460
626
function pointer2memref (source:: Value ; result:: IR.Type , location= Location ())
461
627
op_ty_results = IR. Type[result,]
462
628
operands = Value[source,]
@@ -495,14 +661,29 @@ function polygeist_yield(; location=Location())
495
661
)
496
662
end
497
663
664
+ """
665
+ `linalg_qr`
666
+
667
+ This operation computes the QR factorization of a matrix using Householder
668
+ reflections. Mathematically, it decomposes A into the product of an
669
+ orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
670
+ such that A = QR.
671
+
672
+ If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
673
+ will be a m x n trapezoidal matrix.
674
+
675
+ This operation is modeled after the mathematical formulation of the QR
676
+ factorization, and not after LAPACK\' s compact formats.
677
+ """
498
678
function linalg_qr (
499
- input:: Value ; output :: IR.Type , tau :: IR.Type , info :: IR.Type , location= Location ()
679
+ input:: Value ; Q :: IR.Type , R :: IR.Type , algorithm = nothing , location= Location ()
500
680
)
501
- op_ty_results = IR. Type[output, tau, info ]
681
+ op_ty_results = IR. Type[Q, R ]
502
682
operands = Value[input,]
503
683
owned_regions = Region[]
504
684
successors = Block[]
505
685
attributes = NamedAttribute[]
686
+ ! isnothing (algorithm) && push! (attributes, namedattribute (" algorithm" , algorithm))
506
687
507
688
return create_operation (
508
689
" enzymexla.linalg.qr" ,
@@ -642,12 +823,16 @@ function wrap(
642
823
)
643
824
end
644
825
645
- function xla_wrapper (inputs:: Vector{Value} ; fn, location= Location ())
826
+ function xla_wrapper (
827
+ inputs:: Vector{Value} ; fn, arg_attrs= nothing , res_attrs= nothing , location= Location ()
828
+ )
646
829
op_ty_results = IR. Type[]
647
830
operands = Value[inputs... ,]
648
831
owned_regions = Region[]
649
832
successors = Block[]
650
833
attributes = NamedAttribute[namedattribute (" fn" , fn),]
834
+ ! isnothing (arg_attrs) && push! (attributes, namedattribute (" arg_attrs" , arg_attrs))
835
+ ! isnothing (res_attrs) && push! (attributes, namedattribute (" res_attrs" , res_attrs))
651
836
652
837
return create_operation (
653
838
" enzymexla.xla_wrapper" ,
0 commit comments