Open
Labels: bug
Description
```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def test(x, y):
    return jnp.einsum('...mk,...kn->mn', x, y)

x = jax.random.normal(jax.random.key(0), (2, 64, 128))
y = jax.random.normal(jax.random.key(1), (2, 128, 256))

lowered = jax.jit(test).lower(x, y)
print(lowered.as_text())
print('xla', lowered.compile().cost_analysis()['flops'])
print('pl', pl.estimate_cost(test, x, y).flops)
```

Output:

```
module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x64x128xf32>, %arg1: tensor<2x128x256xf32>) -> (tensor<64x256xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0, 2] x [0, 1], precision = [DEFAULT, DEFAULT] : (tensor<2x64x128xf32>, tensor<2x128x256xf32>) -> tensor<64x256xf32>
    return %0 : tensor<64x256xf32>
  }
}
xla 8388608.0
pl 16777216
```
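Because `...` appears in both inputs but not in the output `mn`, einsum sums over that leading axis as well as over `k`, which is why the HLO above contracts two dimensions on each side (`contracting_dims = [0, 2] x [0, 1]`). A minimal stdlib sketch that recovers those indices from the subscripts, assuming the `...` axis is written as an explicit label `b`; `contracting_dims` is a hypothetical helper for illustration, not a JAX API:

```python
def contracting_dims(spec):
    """Return (lhs_dims, rhs_dims) summed over by a two-operand einsum spec.

    Hypothetical helper: assumes explicit single-letter labels (no '...')
    and no label repeated within a single operand.
    """
    inputs, output = spec.split('->')
    lhs, rhs = inputs.split(',')
    # A label is contracted if it appears in both inputs but not in the output.
    summed = [c for c in lhs if c in rhs and c not in output]
    return [lhs.index(c) for c in summed], [rhs.index(c) for c in summed]

# '...' written out as an explicit label 'b' for the leading size-2 axis.
print(contracting_dims('bmk,bkn->mn'))  # ([0, 2], [0, 1])
```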
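The FLOP count is doubled once per contracting dimension, when it should be doubled only once. With the two contracting dimensions here, `pl.estimate_cost` reports twice the XLA figure.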
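For reference, the expected count for a `dot_general` without batch dimensions is 2 × (output size) × (total contraction size), with the factor of 2 (one multiply plus one add) applied once. The sketch below uses two hypothetical helpers: `dot_general_flops` computes that formula, and `reported_flops` only mimics the numbers observed above by applying the factor once per contracting dimension. Neither is taken from the JAX or Pallas source.

```python
from math import prod

def dot_general_flops(lhs_shape, rhs_shape, lhs_contract, rhs_contract):
    """Hypothetical: FLOPs of a dot_general with no batch dims."""
    contract_size = prod(lhs_shape[d] for d in lhs_contract)
    lhs_free = prod(s for i, s in enumerate(lhs_shape) if i not in lhs_contract)
    rhs_free = prod(s for i, s in enumerate(rhs_shape) if i not in rhs_contract)
    # One multiply and one add per (output element, contracted index) pair,
    # so the factor of 2 is applied once, however many dims are contracted.
    return 2 * lhs_free * rhs_free * contract_size

def reported_flops(lhs_shape, rhs_shape, lhs_contract, rhs_contract):
    """Hypothetical: mimics the reported behaviour (2x per contracting dim)."""
    macs = dot_general_flops(lhs_shape, rhs_shape, lhs_contract, rhs_contract) // 2
    return 2 ** len(lhs_contract) * macs

args = ((2, 64, 128), (2, 128, 256), (0, 2), (0, 1))
print(dot_general_flops(*args))  # 8388608, matches the 'xla' line
print(reported_flops(*args))     # 16777216, matches the 'pl' line
```

With a single contracting dimension the two helpers agree, which is consistent with the discrepancy only showing up when more than one dimension is contracted.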
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.8.1rc5
jaxlib: 0.8.1rc5
numpy: 2.3.5
python: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='9ddccd46aa39', release='6.6.56+', version='#1 SMP PREEMPT_DYNAMIC Sun Nov 10 10:07:59 UTC 2024', machine='x86_64')
```