Skip to content

Commit e1bc1fe

Browse files
Regenerate MLIR Bindings (#1543)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 90bf518 commit e1bc1fe

File tree

3 files changed

+318
-18
lines changed

3 files changed

+318
-18
lines changed

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 188 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,107 @@ function ml_gelu(
244244
)
245245
end
246246

247+
"""
248+
`lapack_gemqrt`
249+
250+
This operation is modeled after LAPACK\'s *GEMQR routines.
251+
"""
252+
function lapack_gemqrt(
253+
V::Value,
254+
T::Value,
255+
C::Value;
256+
output::IR.Type,
257+
side,
258+
transpose=nothing,
259+
location=Location(),
260+
)
261+
op_ty_results = IR.Type[output,]
262+
operands = Value[V, T, C]
263+
owned_regions = Region[]
264+
successors = Block[]
265+
attributes = NamedAttribute[namedattribute("side", side),]
266+
!isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose))
267+
268+
return create_operation(
269+
"enzymexla.lapack.gemqrt",
270+
location;
271+
operands,
272+
owned_regions,
273+
successors,
274+
attributes,
275+
results=op_ty_results,
276+
result_inference=false,
277+
)
278+
end
279+
280+
"""
281+
`lapack_geqrf`
282+
283+
This operation computes the QR factorization of a matrix using Householder
284+
reflections. Mathematically, it decomposes A into the product of an
285+
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
286+
287+
This operation is modeled after LAPACK\'s *GEQRF routines, which returns the
288+
result in the QR packed format.
289+
"""
290+
function lapack_geqrf(
291+
input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()
292+
)
293+
op_ty_results = IR.Type[output, tau, info]
294+
operands = Value[input,]
295+
owned_regions = Region[]
296+
successors = Block[]
297+
attributes = NamedAttribute[]
298+
299+
return create_operation(
300+
"enzymexla.lapack.geqrf",
301+
location;
302+
operands,
303+
owned_regions,
304+
successors,
305+
attributes,
306+
results=op_ty_results,
307+
result_inference=false,
308+
)
309+
end
310+
311+
"""
312+
`lapack_geqrt`
313+
314+
This operation computes the QR factorization of a matrix using Householder
315+
reflections. Mathematically, it decomposes A into the product of an
316+
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
317+
318+
This operation is modeled after LAPACK\'s *GEQRT routines, which returns the
319+
result in the QR CompactWY format.
320+
"""
321+
function lapack_geqrt(
322+
input::Value;
323+
output::IR.Type,
324+
T::IR.Type,
325+
info::IR.Type,
326+
blocksize=nothing,
327+
location=Location(),
328+
)
329+
op_ty_results = IR.Type[output, T, info]
330+
operands = Value[input,]
331+
owned_regions = Region[]
332+
successors = Block[]
333+
attributes = NamedAttribute[]
334+
!isnothing(blocksize) && push!(attributes, namedattribute("blocksize", blocksize))
335+
336+
return create_operation(
337+
"enzymexla.lapack.geqrt",
338+
location;
339+
operands,
340+
owned_regions,
341+
successors,
342+
attributes,
343+
results=op_ty_results,
344+
result_inference=false,
345+
)
346+
end
347+
247348
function get_stream(; result::IR.Type, location=Location())
248349
op_ty_results = IR.Type[result,]
249350
operands = Value[]
@@ -270,6 +371,8 @@ function jit_call(
270371
backend_config=nothing,
271372
operand_layouts=nothing,
272373
result_layouts=nothing,
374+
arg_attrs=nothing,
375+
res_attrs=nothing,
273376
output_operand_aliases=nothing,
274377
xla_side_effect_free=nothing,
275378
location=Location(),
@@ -285,6 +388,8 @@ function jit_call(
285388
push!(attributes, namedattribute("operand_layouts", operand_layouts))
286389
!isnothing(result_layouts) &&
287390
push!(attributes, namedattribute("result_layouts", result_layouts))
391+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
392+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
288393
!isnothing(output_operand_aliases) &&
289394
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
290395
!isnothing(xla_side_effect_free) &&
@@ -316,6 +421,8 @@ function kernel_call(
316421
backend_config=nothing,
317422
operand_layouts=nothing,
318423
result_layouts=nothing,
424+
arg_attrs=nothing,
425+
res_attrs=nothing,
319426
output_operand_aliases=nothing,
320427
xla_side_effect_free=nothing,
321428
location=Location(),
@@ -331,6 +438,8 @@ function kernel_call(
331438
push!(attributes, namedattribute("operand_layouts", operand_layouts))
332439
!isnothing(result_layouts) &&
333440
push!(attributes, namedattribute("result_layouts", result_layouts))
441+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
442+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
334443
!isnothing(output_operand_aliases) &&
335444
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
336445
!isnothing(xla_side_effect_free) &&
@@ -457,6 +566,63 @@ function noop(blockDims::Vector{Value}; location=Location())
457566
)
458567
end
459568

569+
"""
570+
`lapack_orgqr`
571+
572+
This operation is modeled after LAPACK\'s *ORGQR/*UNGQR routines.
573+
"""
574+
function lapack_orgqr(input::Value, tau::Value; output::IR.Type, location=Location())
575+
op_ty_results = IR.Type[output,]
576+
operands = Value[input, tau]
577+
owned_regions = Region[]
578+
successors = Block[]
579+
attributes = NamedAttribute[]
580+
581+
return create_operation(
582+
"enzymexla.lapack.orgqr",
583+
location;
584+
operands,
585+
owned_regions,
586+
successors,
587+
attributes,
588+
results=op_ty_results,
589+
result_inference=false,
590+
)
591+
end
592+
593+
"""
594+
`lapack_ormqr`
595+
596+
This operation is modeled after LAPACK\'s *ORMQR routines.
597+
"""
598+
function lapack_ormqr(
599+
A::Value,
600+
tau::Value,
601+
C::Value;
602+
output::IR.Type,
603+
side,
604+
transpose=nothing,
605+
location=Location(),
606+
)
607+
op_ty_results = IR.Type[output,]
608+
operands = Value[A, tau, C]
609+
owned_regions = Region[]
610+
successors = Block[]
611+
attributes = NamedAttribute[namedattribute("side", side),]
612+
!isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose))
613+
614+
return create_operation(
615+
"enzymexla.lapack.ormqr",
616+
location;
617+
operands,
618+
owned_regions,
619+
successors,
620+
attributes,
621+
results=op_ty_results,
622+
result_inference=false,
623+
)
624+
end
625+
460626
function pointer2memref(source::Value; result::IR.Type, location=Location())
461627
op_ty_results = IR.Type[result,]
462628
operands = Value[source,]
@@ -495,14 +661,29 @@ function polygeist_yield(; location=Location())
495661
)
496662
end
497663

664+
"""
665+
`linalg_qr`
666+
667+
This operation computes the QR factorization of a matrix using Householder
668+
reflections. Mathematically, it decomposes A into the product of an
669+
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
670+
such that A = QR.
671+
672+
If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
673+
will be a m x n trapezoidal matrix.
674+
675+
This operation is modeled after the mathematical formulation of the QR
676+
factorization, and not after LAPACK\'s compact formats.
677+
"""
498678
function linalg_qr(
499-
input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()
679+
input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location()
500680
)
501-
op_ty_results = IR.Type[output, tau, info]
681+
op_ty_results = IR.Type[Q, R]
502682
operands = Value[input,]
503683
owned_regions = Region[]
504684
successors = Block[]
505685
attributes = NamedAttribute[]
686+
!isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm))
506687

507688
return create_operation(
508689
"enzymexla.linalg.qr",
@@ -642,12 +823,16 @@ function wrap(
642823
)
643824
end
644825

645-
function xla_wrapper(inputs::Vector{Value}; fn, location=Location())
826+
function xla_wrapper(
827+
inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location()
828+
)
646829
op_ty_results = IR.Type[]
647830
operands = Value[inputs...,]
648831
owned_regions = Region[]
649832
successors = Block[]
650833
attributes = NamedAttribute[namedattribute("fn", fn),]
834+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
835+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
651836

652837
return create_operation(
653838
"enzymexla.xla_wrapper",

src/mlir/Dialects/MosaicGPU.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ end
404404
Schedules `tcgen05.mma` instructions that perform the following matrix
405405
multiply and accumulate:
406406
407-
accumulator = a * b + accumulator
407+
accumulator += a * b
408408
409409
This operation supports larger inputs than the PTX-level MMA instruction
410410
and will schedule as many PTX-level MMA instructions as needed to
@@ -417,8 +417,6 @@ The inputs should have the following shapes:
417417
where `s == swizzle / element_bytewidth` and `m` is specified according to
418418
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape.
419419
420-
The output has an identical shape and type as the input accumulator.
421-
422420
The `accumulator`, `a` and `b` matrices need to be provided as 2-dimensional
423421
memrefs. The `accumulator` is always in TMEM and `b` is always in SMEM.
424422
`a` can be in TMEM or SMEM. `a` and `b` must have the same element
@@ -427,7 +425,7 @@ type and when `a` is in TMEM only F16 or BF16 are supported.
427425
`a_scale` and `b_scale` are optional scaling matrices that reside in TMEM.
428426
When set the operation is defined as:
429427
430-
accumulator = (a * a_scale) * (b * b_scale) + accumulator
428+
accumulator += (a * a_scale) * (b * b_scale)
431429
432430
`accumulate` is a boolean that indicates whether to perform the accumulate
433431
step.
@@ -439,7 +437,6 @@ function tcgen05_mma(
439437
accumulate::Value,
440438
a_scale=nothing::Union{Nothing,Value};
441439
b_scale=nothing::Union{Nothing,Value},
442-
result_0=nothing::Union{Nothing,IR.Type},
443440
collective=nothing,
444441
location=Location(),
445442
)
@@ -463,7 +460,6 @@ function tcgen05_mma(
463460
1
464461
end,
465462
]))
466-
!isnothing(result_0) && push!(op_ty_results, result_0)
467463
!isnothing(collective) && push!(attributes, namedattribute("collective", collective))
468464

469465
return create_operation(
@@ -473,8 +469,8 @@ function tcgen05_mma(
473469
owned_regions,
474470
successors,
475471
attributes,
476-
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
477-
result_inference=(length(op_ty_results) == 0 ? true : false),
472+
results=op_ty_results,
473+
result_inference=false,
478474
)
479475
end
480476

0 commit comments

Comments
 (0)