Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions jax/_src/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ py_library(
"//jax/_src/lib",
] + py_deps("numpy"),
)

py_test(
name = "cost_estimate_test",
srcs = ["test.py"],
deps = [
":pallas",
"//jax",
] + py_deps("absl/testing"),
)
9 changes: 6 additions & 3 deletions jax/_src/pallas/cost_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,14 @@ def dot_general_cost_rule(ctx: Context,
assert len(lhs_contracting_dims) == len(rhs_contracting_dims)
assert len(lhs_batch_dims) == len(rhs_batch_dims)
flops = 1
# Flops along a contracting dim is 2*dim (addition and multiplication)

contracting_size=1
for i in range(len(lhs_contracting_dims)):
lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i]
lhs_dim, rhs_dim=lhs_contracting_dims[i], rhs_contracting_dims[i]
assert x_shape[lhs_dim] == y_shape[rhs_dim]
flops *= 2 * x_shape[lhs_dim]
contracting_size *= x_shape[lhs_dim]

flops *= 2 * contracting_size
# Now we handle all other dimensions.
for i, lhs_dim in enumerate(x_shape):
if i in lhs_contracting_dims:
Expand Down
63 changes: 63 additions & 0 deletions jax/_src/pallas/cost_estimate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cost_estimate.py."""

from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl


class CostEstimateTest(absltest.TestCase):
"""Tests for Pallas cost estimation."""

def test_dot_general_single_contracting_dim(self):
"""Test FLOP counting with single contracting dimension."""

def matmul(x, y):
return jnp.einsum("mk,kn->mn", x, y)

x = jax.random.normal(jax.random.key(0), (64, 128))
y = jax.random.normal(jax.random.key(1), (128, 256))

xla_flops = jax.jit(matmul).lower(x, y).compile().cost_analysis()["flops"]
pl_flops = pl.estimate_cost(matmul, x, y).flops

self.assertEqual(xla_flops, pl_flops)

def test_dot_general_multiple_contracting_dims(self):
"""Test FLOP counting with multiple contracting dimensions.

This is a regression test for https://github.com/jax-ml/jax/issues/33388
where FLOPs were incorrectly doubled for each contracting dimension.
"""

def test(x, y):
return jnp.einsum("...mk,...kn->mn", x, y)

x = jax.random.normal(jax.random.key(0), (2, 64, 128))
y = jax.random.normal(jax.random.key(1), (2, 128, 256))

xla_flops = jax.jit(test).lower(x, y).compile().cost_analysis()["flops"]
pl_flops = pl.estimate_cost(test, x, y).flops

# Expected: 64 * 256 * (2 * 128) * 2 = 8,388,608
expected = 64 * 256 * 2 * 128 * 2

self.assertEqual(xla_flops, expected)
self.assertEqual(pl_flops, expected)


if __name__ == "__main__":
absltest.main()
63 changes: 63 additions & 0 deletions jax/_src/pallas/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cost_estimate.py."""

from absl.testing import absltest
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl


class CostEstimateTest(absltest.TestCase):
"""Tests for Pallas cost estimation."""

def test_dot_general_single_contracting_dim(self):
"""Test FLOP counting with single contracting dimension."""

def matmul(x, y):
return jnp.einsum("mk,kn->mn", x, y)

x = jax.random.normal(jax.random.key(0), (64, 128))
y = jax.random.normal(jax.random.key(1), (128, 256))

xla_flops = jax.jit(matmul).lower(x, y).compile().cost_analysis()["flops"]
pl_flops = pl.estimate_cost(matmul, x, y).flops

self.assertEqual(xla_flops, pl_flops)

def test_dot_general_multiple_contracting_dims(self):
"""Test FLOP counting with multiple contracting dimensions.
This is a regression test for https://github.com/jax-ml/jax/issues/33388
where FLOPs were incorrectly doubled for each contracting dimension.
"""

def test(x, y):
return jnp.einsum("...mk,...kn->mn", x, y)

x = jax.random.normal(jax.random.key(0), (2, 64, 128))
y = jax.random.normal(jax.random.key(1), (2, 128, 256))

xla_flops = jax.jit(test).lower(x, y).compile().cost_analysis()["flops"]
pl_flops = pl.estimate_cost(test, x, y).flops

# Expected: 64 * 256 * (2 * 128) * 2 = 8,388,608
expected = 64 * 256 * 2 * 128 * 2

self.assertEqual(xla_flops, expected)
self.assertEqual(pl_flops, expected)


if __name__ == "__main__":
absltest.main()