Skip to content

Commit df36c29

Browse files
zacmustinGoogle-ML-Automation
authored andcommitted
Compute cost-analysis on only one HLO module.
There was historically a goal to support multiple HLOs in an executable, but this work was never finished and is no longer planned so we don't need this support. This will soon enable us to return only a dict, instead of a list of dicts with only one item. PiperOrigin-RevId: 711477481
1 parent 800f903 commit df36c29

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

jax/_src/stages.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def as_text(self) -> str:
249249
else:
250250
raise
251251

252-
# TODO(skyewm): this should return a single dict (I think returning a list
253-
# was to support MPMD executables, which never fully landed)
252+
# TODO(b/384741132): this should return a single dict (I think returning a list
253+
# was to support MPMD executables, which never fully landed).
254254
def cost_analysis(self) -> list[dict[str, float]]:
255255
xla_ext_exe = self.xla_extension_executable()
256256

@@ -266,9 +266,19 @@ def cost_analysis(self) -> list[dict[str, float]]:
266266
# Try client method if executable cost_analysis method is unimplemented
267267
if hasattr(xla_ext_exe, "client"):
268268
try:
269+
# TODO(b/384741132): We expect that the executable has only one
270+
# HloModule. We should be able to remove this check once we update the
271+
# Executable class to have only a single HloModule (see bug).
272+
hlo_modules = xla_ext_exe.hlo_modules()
273+
assert len(hlo_modules) == 1, (
274+
f"Exectuable should have only one HloModule ({len(hlo_modules)})"
275+
" were found)."
276+
)
277+
269278
return [
270-
xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
271-
for m in xla_ext_exe.hlo_modules()
279+
xla_extension.hlo_module_cost_analysis(
280+
xla_ext_exe.client, hlo_modules[0]
281+
)
272282
]
273283
except xla_extension.XlaRuntimeError as e:
274284
msg, *_ = e.args

0 commit comments

Comments
 (0)