|
3 | 3 | # Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
|
4 | 4 | module Ops
|
5 | 5 | using ..MLIR: MLIR
|
6 |
| -using ..MLIR.Dialects: stablehlo, chlo, enzyme |
| 6 | +using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla |
7 | 7 | using ..Reactant:
|
8 | 8 | Reactant,
|
9 | 9 | TracedRArray,
|
@@ -3003,7 +3003,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
|
3003 | 3003 | permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
|
3004 | 3004 | info_shape = batch_shape
|
3005 | 3005 |
|
3006 |
| - op = MLIR.Dialects.enzymexla.linalg_lu( |
| 3006 | + op = enzymexla.linalg_lu( |
3007 | 3007 | x.mlir_data;
|
3008 | 3008 | output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
|
3009 | 3009 | pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
|
@@ -3210,4 +3210,73 @@ end
|
3210 | 3210 | end
|
3211 | 3211 | end
|
3212 | 3212 |
|
| 3213 | +@noinline function wrap( |
| 3214 | + input::TracedRArray{T,N}, |
| 3215 | + lhs::Integer, |
| 3216 | + rhs::Integer; |
| 3217 | + dimension::Int, |
| 3218 | + location=mlir_stacktrace("wrap", @__FILE__, @__LINE__), |
| 3219 | +) where {T,N} |
| 3220 | + @assert 1 ≤ dimension ≤ N "dimension must be between 1 and $(N) (got $(dimension))" |
| 3221 | + @assert 0 ≤ lhs ≤ size(input, dimension) "lhs must be between 0 and \ |
| 3222 | + $(size(input, dimension)) (got $(lhs))" |
| 3223 | + @assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \ |
| 3224 | + $(size(input, dimension)) (got $(rhs))" |
| 3225 | + return TracedRArray{T,N}( |
| 3226 | + (), |
| 3227 | + MLIR.IR.result( |
| 3228 | + enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1 |
| 3229 | + ), |
| 3230 | + size(input), |
| 3231 | + ) |
| 3232 | +end |
| 3233 | + |
| 3234 | +@noinline function extend( |
| 3235 | + input::TracedRArray{T,N}, |
| 3236 | + lhs::Integer, |
| 3237 | + rhs::Integer; |
| 3238 | + dimension::Int, |
| 3239 | + location=mlir_stacktrace("extend", @__FILE__, @__LINE__), |
| 3240 | +) where {T,N} |
| 3241 | + @assert 1 ≤ dimension ≤ N "dimension must be between 1 and $(N) (got $(dimension))" |
| 3242 | + @assert 0 ≤ lhs ≤ size(input, dimension) "lhs must be between 0 and \ |
| 3243 | + $(size(input, dimension)) (got $(lhs))" |
| 3244 | + @assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \ |
| 3245 | + $(size(input, dimension)) (got $(rhs))" |
| 3246 | + sz = collect(Int64, size(input)) |
| 3247 | + sz[dimension] = sz[dimension] + lhs + rhs |
| 3248 | + return TracedRArray{T,N}( |
| 3249 | + (), |
| 3250 | + MLIR.IR.result( |
| 3251 | + enzymexla.extend(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), |
| 3252 | + 1, |
| 3253 | + ), |
| 3254 | + sz, |
| 3255 | + ) |
| 3256 | +end |
| 3257 | + |
| 3258 | +@noinline function rotate( |
| 3259 | + input::TracedRArray{T,N}, |
| 3260 | + amount::Integer; |
| 3261 | + dimension::Int, |
| 3262 | + location=mlir_stacktrace("rotate", @__FILE__, @__LINE__), |
| 3263 | +) where {T,N} |
| 3264 | + @assert 1 ≤ dimension ≤ N "dimension must be between 1 and $(N) (got $(dimension))" |
| 3265 | + @assert 0 ≤ amount ≤ size(input, dimension) "amount must be between 0 and \ |
| 3266 | + $(size(input, dimension)) (got $(amount))" |
| 3267 | + return TracedRArray{T,N}( |
| 3268 | + (), |
| 3269 | + MLIR.IR.result( |
| 3270 | + enzymexla.rotate( |
| 3271 | + input.mlir_data; |
| 3272 | + amount=Int32(amount), |
| 3273 | + dimension=Int32(dimension - 1), |
| 3274 | + location, |
| 3275 | + ), |
| 3276 | + 1, |
| 3277 | + ), |
| 3278 | + size(input), |
| 3279 | + ) |
| 3280 | +end |
| 3281 | + |
3213 | 3282 | end # module Ops
|
0 commit comments