dot_general_cost_rule incorrect for multiple contracting dims #33388

@sssshhhhhh

Description

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def test(x, y):
    return jnp.einsum('...mk,...kn->mn', x, y)

x = jax.random.normal(jax.random.key(0), (2, 64, 128))
y = jax.random.normal(jax.random.key(1), (2, 128, 256))
lowered = jax.jit(test).lower(x, y)
print(lowered.as_text())
print('xla', lowered.compile().cost_analysis()['flops'])
print('pl', pl.estimate_cost(test, x, y).flops)

Output:

module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x64x128xf32>, %arg1: tensor<2x128x256xf32>) -> (tensor<64x256xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0, 2] x [0, 1], precision = [DEFAULT, DEFAULT] : (tensor<2x64x128xf32>, tensor<2x128x256xf32>) -> tensor<64x256xf32>
    return %0 : tensor<64x256xf32>
  }
}
xla 8388608.0
pl 16777216
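For reference, here is the expected count for this `dot_general`, counting one multiply and one add per multiply-accumulate. The output is $M \times N = 64 \times 256$, and the contracting dims have sizes $2$ and $128$:

```latex
\text{FLOPs} = 2 \cdot M \cdot N \cdot \prod_i K_i
            = 2 \cdot 64 \cdot 256 \cdot (2 \cdot 128)
            = 8\,388\,608
```

This matches the XLA figure. The Pallas figure, $16\,777\,216 = 2^2 \cdot 64 \cdot 256 \cdot (2 \cdot 128)$, carries one factor of 2 per contracting dimension.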

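The Pallas estimate doubles the FLOP count once for each contracting dim, but the factor of 2 (one multiply plus one add per multiply-accumulate) should be applied only once in total. With two contracting dims here, `pl.estimate_cost` reports twice the XLA figure.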
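The sketch below shows the expected arithmetic in plain Python, so it runs without JAX. `dot_general_flops` and `per_dim_doubling_flops` are hypothetical helpers written only for illustration. They are not JAX's `dot_general_cost_rule`, and the second one models the behavior described above rather than reproducing the actual Pallas source.

```python
import math


def _free_size(shape, contracting, batch):
    # Product of the dims that are neither contracted nor batched.
    return math.prod(s for i, s in enumerate(shape)
                     if i not in contracting and i not in batch)


def dot_general_flops(lhs_shape, rhs_shape, lhs_contracting, rhs_contracting,
                      lhs_batch=(), rhs_batch=()):
    """Hypothetical helper: 2 FLOPs (multiply + add) per multiply-accumulate."""
    for l, r in zip(lhs_contracting, rhs_contracting):
        assert lhs_shape[l] == rhs_shape[r], "contracting dims must match"
    batch = math.prod(lhs_shape[d] for d in lhs_batch)
    k = math.prod(lhs_shape[d] for d in lhs_contracting)
    out_elems = (batch * _free_size(lhs_shape, lhs_contracting, lhs_batch)
                 * _free_size(rhs_shape, rhs_contracting, rhs_batch))
    # The factor of 2 is applied exactly once, however many dims are contracted.
    return 2 * out_elems * k


def per_dim_doubling_flops(lhs_shape, rhs_shape, lhs_contracting, rhs_contracting,
                           lhs_batch=(), rhs_batch=()):
    """Hypothetical model of the reported bug: factor of 2 applied per contracting dim."""
    batch = math.prod(lhs_shape[d] for d in lhs_batch)
    flops = (batch * _free_size(lhs_shape, lhs_contracting, lhs_batch)
             * _free_size(rhs_shape, rhs_contracting, rhs_batch))
    for d in lhs_contracting:
        flops *= 2 * lhs_shape[d]
    return flops


# Shapes and contracting dims from the StableHLO above: [0, 2] x [0, 1].
args = ((2, 64, 128), (2, 128, 256), (0, 2), (0, 1))
print(dot_general_flops(*args))        # 8388608, matches XLA
print(per_dim_doubling_flops(*args))   # 16777216, matches the Pallas estimate
```

With a single contracting dim the two helpers agree, which is consistent with the discrepancy only appearing when more than one dim is contracted.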

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.8.1rc5
jaxlib: 0.8.1rc5
numpy:  2.3.5
python: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='9ddccd46aa39', release='6.6.56+', version='#1 SMP PREEMPT_DYNAMIC Sun Nov 10 10:07:59 UTC 2024', machine='x86_64')

Metadata

Labels: bug (Something isn't working)