integrate with OMEinsum #2222

Open
mofeing wants to merge 18 commits into main from ss/omeinsum-ext

Conversation

mofeing (Collaborator) commented Jan 26, 2026

to do


@oschulz this should help with #1586

mofeing (Collaborator, Author) commented Jan 30, 2026

Hey @GiggleLiu, I'm adding Reactant support for OMEinsum and was wondering whether there are any special cases missing from the tests.

mofeing marked this pull request as ready for review on January 30, 2026, 14:09
mofeing (Collaborator, Author) commented Jan 30, 2026

@avik-pal @wsmoses I think there is a bug in the ReduceSliceFusion pass: it segfaults in CI and hangs locally.

MWE

The MWE computes the trace of a 2x2 matrix:

using Reactant, OMEinsum

a = rand(2, 2)
are = Reactant.to_rarray(a)
@code_hlo ein"ii->"(are)   # triggers the segfault/hang in the optimization pipeline

Note that extracting the diagonal (i.e. ein"ii->i") does work.
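
For reference, the diagonal-only path (same are as above):

@code_hlo ein"ii->i"(are)   # lowers fine, no segfault or hang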

MLIR code before optimizations

julia> @code_hlo optimize=false ein"ii->"(are)
module @"reactant_ii -> " attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> tensor<f64> attributes {enzymexla.memory_effects = []} {
    return %arg0 : tensor<f64>
  }
  func.func private @"*_broadcast_scalar"(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
    %c = stablehlo.constant dense<true> : tensor<i1>
    %0 = stablehlo.convert %c : (tensor<i1>) -> tensor<f64>
    %1 = stablehlo.multiply %0, %arg0 : tensor<f64>
    return %1, %arg0 : tensor<f64>, tensor<f64>
  }
  func.func @main(%arg0: tensor<2x2xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"], tf.aliasing_output = 1 : i32}) -> (tensor<f64>, tensor<2x2xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %1 = stablehlo.convert %c : (tensor<i64>) -> tensor<f64>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f64>) -> tensor<f64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<2xf64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %3 = stablehlo.convert %c_0 : (tensor<i64>) -> tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<i32>
    %4 = stablehlo.convert %c_1 : tensor<i32>
    %5 = stablehlo.subtract %3, %4 : tensor<i32>
    %c_2 = stablehlo.constant dense<1> : tensor<i64>
    %6 = stablehlo.convert %c_2 : (tensor<i64>) -> tensor<i32>
    %c_3 = stablehlo.constant dense<1> : tensor<i32>
    %7 = stablehlo.convert %c_3 : tensor<i32>
    %8 = stablehlo.subtract %6, %7 : tensor<i32>
    %9 = stablehlo.dynamic_slice %0, %5, %8, sizes = [1, 1] : (tensor<2x2xf64>, tensor<i32>, tensor<i32>) -> tensor<1x1xf64>
    %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<1x1xf64>) -> tensor<1x1xf64>
    %11 = stablehlo.reshape %10 : (tensor<1x1xf64>) -> tensor<f64>
    %12 = stablehlo.transpose %11, dims = [] : (tensor<f64>) -> tensor<f64>
    %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<f64>) -> tensor<1xf64>
    %c_4 = stablehlo.constant dense<1> : tensor<i64>
    %14 = stablehlo.convert %c_4 : (tensor<i64>) -> tensor<i32>
    %c_5 = stablehlo.constant dense<1> : tensor<i32>
    %15 = stablehlo.convert %c_5 : tensor<i32>
    %16 = stablehlo.subtract %14, %15 : tensor<i32>
    %17 = stablehlo.dynamic_update_slice %cst, %13, %16 : (tensor<2xf64>, tensor<1xf64>, tensor<i32>) -> tensor<2xf64>
    %18 = stablehlo.transpose %17, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %19 = stablehlo.reshape %18 : (tensor<2xf64>) -> tensor<2xf64>
    %20 = stablehlo.transpose %19, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %c_6 = stablehlo.constant dense<2> : tensor<i64>
    %21 = stablehlo.convert %c_6 : (tensor<i64>) -> tensor<i32>
    %c_7 = stablehlo.constant dense<1> : tensor<i32>
    %22 = stablehlo.convert %c_7 : tensor<i32>
    %23 = stablehlo.subtract %21, %22 : tensor<i32>
    %c_8 = stablehlo.constant dense<2> : tensor<i64>
    %24 = stablehlo.convert %c_8 : (tensor<i64>) -> tensor<i32>
    %c_9 = stablehlo.constant dense<1> : tensor<i32>
    %25 = stablehlo.convert %c_9 : tensor<i32>
    %26 = stablehlo.subtract %24, %25 : tensor<i32>
    %27 = stablehlo.dynamic_slice %0, %23, %26, sizes = [1, 1] : (tensor<2x2xf64>, tensor<i32>, tensor<i32>) -> tensor<1x1xf64>
    %28 = stablehlo.transpose %27, dims = [1, 0] : (tensor<1x1xf64>) -> tensor<1x1xf64>
    %29 = stablehlo.reshape %28 : (tensor<1x1xf64>) -> tensor<f64>
    %30 = stablehlo.transpose %29, dims = [] : (tensor<f64>) -> tensor<f64>
    %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor<f64>) -> tensor<1xf64>
    %c_10 = stablehlo.constant dense<2> : tensor<i64>
    %32 = stablehlo.convert %c_10 : (tensor<i64>) -> tensor<i32>
    %c_11 = stablehlo.constant dense<1> : tensor<i32>
    %33 = stablehlo.convert %c_11 : tensor<i32>
    %34 = stablehlo.subtract %32, %33 : tensor<i32>
    %35 = stablehlo.dynamic_update_slice %20, %31, %34 : (tensor<2xf64>, tensor<1xf64>, tensor<i32>) -> tensor<2xf64>
    %36 = stablehlo.transpose %35, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %37 = stablehlo.reshape %36 : (tensor<2xf64>) -> tensor<2xf64>
    %38 = stablehlo.transpose %37, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %cst_12 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %39 = stablehlo.convert %cst_12 : tensor<f64>
    %40 = enzyme.batch @identity_broadcast_scalar(%38) {batch_shape = array<i64: 2>} : (tensor<2xf64>) -> tensor<2xf64>
    %c_13 = stablehlo.constant dense<0> : tensor<i64>
    %41 = stablehlo.convert %c_13 : (tensor<i64>) -> tensor<f64>
    %c_14 = stablehlo.constant dense<0> : tensor<i64>
    %42 = stablehlo.convert %c_14 : (tensor<i64>) -> tensor<f64>
    %43 = stablehlo.reduce(%40 init: %39) applies stablehlo.add across dimensions = [0] : (tensor<2xf64>, tensor<f64>) -> tensor<f64>
    %44 = stablehlo.transpose %43, dims = [] : (tensor<f64>) -> tensor<f64>
    %45 = stablehlo.reshape %44 : (tensor<f64>) -> tensor<1xf64>
    %46 = stablehlo.transpose %45, dims = [0] : (tensor<1xf64>) -> tensor<1xf64>
    %47 = stablehlo.broadcast_in_dim %46, dims = [0] : (tensor<1xf64>) -> tensor<1xf64>
    %48 = stablehlo.broadcast_in_dim %47, dims = [0] : (tensor<1xf64>) -> tensor<1xf64>
    %49:2 = enzyme.batch @"*_broadcast_scalar"(%48) {batch_shape = array<i64: 1>} : (tensor<1xf64>) -> (tensor<1xf64>, tensor<1xf64>)
    %50 = stablehlo.transpose %49#0, dims = [0] : (tensor<1xf64>) -> tensor<1xf64>
    %51 = stablehlo.reshape %50 : (tensor<1xf64>) -> tensor<f64>
    %52 = stablehlo.transpose %51, dims = [] : (tensor<f64>) -> tensor<f64>
    %53 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    return %52, %53 : tensor<f64>, tensor<2x2xf64>
  }
}

GiggleLiu commented

Hi @mofeing, thanks for adding Reactant support for OMEinsum!

The test coverage looks good. A few additional cases you might consider:

  1. N-ary contractions: ein"ia,ib,ic->abc" (3+ tensor contraction)
  2. Star contraction: ein"ai,bi,ci->abc" (multiple tensors sharing one index)
  3. Partial trace: ein"iijk->jk" (trace over subset of repeated indices)
  4. Hyper-index cases: ein"iii->i" (index repeated 3+ times)
  5. Mixed patterns: ein"ijk,jk->i" (contraction + implicit sum)

These cover some tensor network patterns that come up in practice. Let me know if you run into any issues!
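
In case it helps, here is a rough sketch of how these could be exercised (array sizes are arbitrary, and the @code_hlo call at the end just mirrors the usage in the MWE above):

using OMEinsum, Reactant

x, y, z = rand(2, 3), rand(2, 4), rand(2, 5)    # shared index i has size 2
ein"ia,ib,ic->abc"(x, y, z)                      # n-ary contraction, 3x4x5 result
ein"ai,bi,ci->abc"(permutedims(x), permutedims(y), permutedims(z))  # star form

t = rand(2, 2, 3, 4)
ein"iijk->jk"(t)                                 # partial trace over the repeated index i

u = rand(2, 2, 2)
ein"iii->i"(u)                                   # hyper-index: i repeated three times

v, w = rand(2, 3, 4), rand(3, 4)
ein"ijk,jk->i"(v, w)                             # contraction plus implicit sum

@code_hlo ein"iijk->jk"(Reactant.to_rarray(t))   # traced counterpart of any of the above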

mofeing (Collaborator, Author) commented Jan 31, 2026

  • N-ary contractions: ein"ia,ib,ic->abc" (3+ tensor contraction)
  • Star contraction: ein"ai,bi,ci->abc" (multiple tensors sharing one index)

I don't see a difference between these two cases. Also, does OMEinsum contract them together in a single routine, or does it decompose it into ein"ai,bi->abi" and then ein"abi,ci->abc"?
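
Unless I'm misreading the specs, they agree up to the index order within each input, e.g.:

x, y, z = rand(2, 3), rand(2, 4), rand(2, 5)
ein"ia,ib,ic->abc"(x, y, z) ≈ ein"ai,bi,ci->abc"(permutedims(x), permutedims(y), permutedims(z))   # true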

  • Partial trace: ein"iijk->jk" (trace over subset of repeated indices)

Trace is currently broken due to a bug in a pass, but more importantly, it currently unrolls the for loop in _compactify! and generates a lot of code. I might need to add a custom version of _compactify! for this extension.
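
(Just thinking out loud, not what the extension emits today and not necessarily the fix: the trace could in principle be written as a mask-and-reduce, so it lowers to one multiply plus one reduce instead of per-element dynamic slices and updates. I haven't checked how cleanly this traces.)

using LinearAlgebra

# hypothetical sketch, assuming a Float64 input for simplicity
trace_mask(A) = sum(A .* Matrix{Float64}(I, size(A)...))
trace_mask(rand(2, 2))   # equals the trace of the matrix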

Also, the current "trace" test doesn't seem to trigger the Tr rule of unary_einsum!, but the Diag rule instead. Is that correct?
