Skip to content

Commit 37fbdbd

Browse files
authored
feat: add enzymexla op wrappers (#1453)
1 parent 2fc1fd9 commit 37fbdbd

File tree

1 file changed

+71
-2
lines changed

1 file changed

+71
-2
lines changed

src/Ops.jl

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
44
module Ops
55
using ..MLIR: MLIR
6-
using ..MLIR.Dialects: stablehlo, chlo, enzyme
6+
using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
77
using ..Reactant:
88
Reactant,
99
TracedRArray,
@@ -3003,7 +3003,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
30033003
permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
30043004
info_shape = batch_shape
30053005

3006-
op = MLIR.Dialects.enzymexla.linalg_lu(
3006+
op = enzymexla.linalg_lu(
30073007
x.mlir_data;
30083008
output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
30093009
pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
@@ -3210,4 +3210,73 @@ end
32103210
end
32113211
end
32123212

3213+
@noinline function wrap(
3214+
input::TracedRArray{T,N},
3215+
lhs::Integer,
3216+
rhs::Integer;
3217+
dimension::Int,
3218+
location=mlir_stacktrace("wrap", @__FILE__, @__LINE__),
3219+
) where {T,N}
3220+
@assert 1 dimension N "dimension must be between 1 and $(N) (got $(dimension))"
3221+
@assert 0 lhs size(input, dimension) "lhs must be between 0 and \
3222+
$(size(input, dimension)) (got $(lhs))"
3223+
@assert 0 rhs size(input, dimension) "rhs must be between 0 and \
3224+
$(size(input, dimension)) (got $(rhs))"
3225+
return TracedRArray{T,N}(
3226+
(),
3227+
MLIR.IR.result(
3228+
enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1
3229+
),
3230+
size(input),
3231+
)
3232+
end
3233+
3234+
@noinline function extend(
3235+
input::TracedRArray{T,N},
3236+
lhs::Integer,
3237+
rhs::Integer;
3238+
dimension::Int,
3239+
location=mlir_stacktrace("extend", @__FILE__, @__LINE__),
3240+
) where {T,N}
3241+
@assert 1 dimension N "dimension must be between 1 and $(N) (got $(dimension))"
3242+
@assert 0 lhs size(input, dimension) "lhs must be between 0 and \
3243+
$(size(input, dimension)) (got $(lhs))"
3244+
@assert 0 rhs size(input, dimension) "rhs must be between 0 and \
3245+
$(size(input, dimension)) (got $(rhs))"
3246+
sz = collect(Int64, size(input))
3247+
sz[dimension] = sz[dimension] + lhs + rhs
3248+
return TracedRArray{T,N}(
3249+
(),
3250+
MLIR.IR.result(
3251+
enzymexla.extend(input.mlir_data; lhs, rhs, dimension=dimension - 1, location),
3252+
1,
3253+
),
3254+
sz,
3255+
)
3256+
end
3257+
3258+
@noinline function rotate(
3259+
input::TracedRArray{T,N},
3260+
amount::Integer;
3261+
dimension::Int,
3262+
location=mlir_stacktrace("rotate", @__FILE__, @__LINE__),
3263+
) where {T,N}
3264+
@assert 1 dimension N "dimension must be between 1 and $(N) (got $(dimension))"
3265+
@assert 0 amount size(input, dimension) "amount must be between 0 and \
3266+
$(size(input, dimension)) (got $(amount))"
3267+
return TracedRArray{T,N}(
3268+
(),
3269+
MLIR.IR.result(
3270+
enzymexla.rotate(
3271+
input.mlir_data;
3272+
amount=Int32(amount),
3273+
dimension=Int32(dimension - 1),
3274+
location,
3275+
),
3276+
1,
3277+
),
3278+
size(input),
3279+
)
3280+
end
3281+
32133282
end # module Ops

0 commit comments

Comments
 (0)