Skip to content

Commit 2e59711

Browse files
EetusjoapivovarovGoogle-ML-Automationmgoldfarb-nvidiaamd-songpiao
authored
Backport python multihost hlo runner (#402)
* Add `requirements_lock_3_12.txt` to Copybara config. Include the Python 3.12 requirements lock file in the files managed by Copybara for XLA. PiperOrigin-RevId: 810924536 (cherry picked from commit d8834a1) * Use a unique launch ID for every execution. Reverts 605332e PiperOrigin-RevId: 799658112 (cherry picked from commit 9d3df67) * Moves the profile uploading to be before FetchAndLogOutput so no trailing MemcpyD2H will be included in the trace data PiperOrigin-RevId: 802910040 (cherry picked from commit 6f1b953) * Allow hlo runner to profile multiple repeats. - If `num_repeats_with_profiler=3 recreate_profiler_session_between_repeats=false`, then a single profiling session is created for the last 3 repeats - If `num_repeats_with_profiler=3 recreate_profiler_session_between_repeats=false`, then we profile the last 3 repeats with 3 separated profiling sessions. PiperOrigin-RevId: 802933912 (cherry picked from commit e0f5b99) * PR openxla#30706: Expose Multi-Host HLO Runner in Python Imported from GitHub PR openxla#30706 📝 Summary of Changes Exposes the multi-host runner via nanobind interface for calling by Python programs that register custom calls. 🎯 Justification HLOs containing custom calls are not executable because the custom call targets are not linked. This change provides a straightforward path by allowing for registration of calls from python. 🚀 Kind of Contribution ✨ New Feature Copybara import of the project: -- de1a373 by Michael Goldfarb <[email protected]>: Expose multihost runner to python. -- 797ee1c by Michael Goldfarb <[email protected]>: Cleanups. -- 2271761 by Michael Goldfarb <[email protected]>: Add type registration. -- d67cf0d by Michael Goldfarb <[email protected]>: Remove ns. -- 3b8f477 by Michael Goldfarb <[email protected]>: remove ffi registration from runner. -- e121e98 by Michael Goldfarb <[email protected]>: Add back python registration code. Merging this change closes openxla#30706 COPYBARA_INTEGRATE_REVIEW=openxla#30706 from mgoldfarb-nvidia:mgoldfarb/multihost_runner_py e121e98 PiperOrigin-RevId: 803426356 (cherry picked from commit 09e51fb) * PR openxla#31074: Expose num_repeats_with_profiler option to Python HLO Runner interface Imported from GitHub PR openxla#31074 📝 Summary of Changes Exposes the `num_repeats_with_profiler` which was missed in the first PR. 🎯 Justification Enables profiling with more than 1 iteration. 🚀 Kind of Contribution ♻️ Cleanup Copybara import of the project: -- 8960d9b by Michael Goldfarb <[email protected]>: Expose num_repeats_with_profiler option to Python HLO Runner interface Merging this change closes openxla#31074 COPYBARA_INTEGRATE_REVIEW=openxla#31074 from mgoldfarb-nvidia:mgoldfarb-nvidia/update_hlo_runner 8960d9b PiperOrigin-RevId: 804792478 (cherry picked from commit 34386ae) * PR openxla#32009: [ROCm] fixed the build error on rocm Imported from GitHub PR openxla#32009 🐛 Bug Fix Fixed the build error on ROCm, as cupti_tracer is not available on ROCm platform. It is a separate PR according to the comment in openxla#32002 (comment). @xla-rotation could you review my PR, please? Copybara import of the project: -- d66a44e by Songlin <[email protected]>: fixed build error on rocm Merging this change closes openxla#32009 COPYBARA_INTEGRATE_REVIEW=openxla#32009 from ROCm:ci_fixbuild_multihost_hlo_runner_rocm d66a44e PiperOrigin-RevId: 812790412 (cherry picked from commit 22d1944) * PR openxla#32336: [ROCm] Move cupti_tracer to cuda dependencies in py_hlo_multihost_runner target Imported from GitHub PR openxla#32336 📝 Summary of Changes Move cupti_tracer to cuda dependencies in py_hlo_multihost_runner target 🎯 Justification This PR fixes building py_hlo_multihost_runner on ROCm, where CUPTI is not available, missed in openxla#32012 🚀 Kind of Contribution 🐛 Bug Fix @xla-rotation could I get a review for this PR, please? Copybara import of the project: -- fdd0217 by Eetu Sjöblom <[email protected]>: Move cupti_tracer to cuda dependencies Merging this change closes openxla#32336 COPYBARA_INTEGRATE_REVIEW=openxla#32336 from ROCm:ci_rocm_fix_py_hlo_runner fdd0217 PiperOrigin-RevId: 817277069 (cherry picked from commit 06a2427) --------- Co-authored-by: Alex Pivovarov <[email protected]> Co-authored-by: xla authors <[email protected]> Co-authored-by: Michael Goldfarb <[email protected]> Co-authored-by: spiao <[email protected]>
1 parent c01b39d commit 2e59711

File tree

6 files changed

+690
-13
lines changed

6 files changed

+690
-13
lines changed

requirements_lock_3_12.txt

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
numpy==2.1.3 \
2+
--hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \
3+
--hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \
4+
--hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \
5+
--hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \
6+
--hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \
7+
--hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \
8+
--hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \
9+
--hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \
10+
--hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \
11+
--hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \
12+
--hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \
13+
--hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \
14+
--hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \
15+
--hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \
16+
--hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \
17+
--hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \
18+
--hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \
19+
--hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \
20+
--hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \
21+
--hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \
22+
--hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \
23+
--hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \
24+
--hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \
25+
--hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \
26+
--hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \
27+
--hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \
28+
--hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \
29+
--hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \
30+
--hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \
31+
--hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \
32+
--hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \
33+
--hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \
34+
--hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \
35+
--hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \
36+
--hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \
37+
--hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \
38+
--hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \
39+
--hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \
40+
--hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \
41+
--hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \
42+
--hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \
43+
--hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \
44+
--hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \
45+
--hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \
46+
--hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \
47+
--hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \
48+
--hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \
49+
--hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \
50+
--hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \
51+
--hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \
52+
--hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \
53+
--hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \
54+
--hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \
55+
--hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \
56+
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
57+
lit==17.0.6 \
58+
--hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b
59+
ml-dtypes==0.5.3 \
60+
--hash=sha256:01de48de4537dc3c46e684b969a40ec36594e7eeb7c69e9a093e7239f030a28a \
61+
--hash=sha256:0a1d68a7cb53e3f640b2b6a34d12c0542da3dd935e560fdf463c0c77f339fc20 \
62+
--hash=sha256:0cd5a6c711b5350f3cbc2ac28def81cd1c580075ccb7955e61e9d8f4bfd40d24 \
63+
--hash=sha256:0e44a3761f64bc009d71ddb6d6c71008ba21b53ab6ee588dadab65e2fa79eafc \
64+
--hash=sha256:156418abeeda48ea4797db6776db3c5bdab9ac7be197c1233771e0880c304057 \
65+
--hash=sha256:19f6c3a4f635c2fc9e2aa7d91416bd7a3d649b48350c51f7f715a09370a90d93 \
66+
--hash=sha256:1b255acada256d1fa8c35ed07b5f6d18bc21d1556f842fbc2d5718aea2cd9e55 \
67+
--hash=sha256:1db60c154989af253f6c4a34e8a540c2c9dce4d770784d426945e09908fbb177 \
68+
--hash=sha256:2db74788fc01914a3c7f7da0763427280adfc9cd377e9604b6b64eb8097284bd \
69+
--hash=sha256:4a177b882667c69422402df6ed5c3428ce07ac2c1f844d8a1314944651439458 \
70+
--hash=sha256:4cae435a68861660af81fa3c5af16b70ca11a17275c5b662d9c6f58294e0f113 \
71+
--hash=sha256:5103856a225465371fe119f2fef737402b705b810bd95ad5f348e6e1a6ae21af \
72+
--hash=sha256:58e39349d820b5702bb6f94ea0cb2dc8ec62ee81c0267d9622067d8333596a46 \
73+
--hash=sha256:5ab039ffb40f3dc0aeeeba84fd6c3452781b5e15bef72e2d10bcb33e4bbffc39 \
74+
--hash=sha256:5ee72568d46b9533ad54f78b1e1f3067c0534c5065120ea8ecc6f210d22748b3 \
75+
--hash=sha256:66c2756ae6cfd7f5224e355c893cfd617fa2f747b8bbd8996152cbdebad9a184 \
76+
--hash=sha256:6936283b56d74fbec431ca57ce58a90a908fdbd14d4e2d22eea6d72bb208a7b7 \
77+
--hash=sha256:8b1a6e231b0770f2894910f1dce6d2f31d65884dbf7668f9b08d73623cdca909 \
78+
--hash=sha256:8bb9cd1ce63096567f5f42851f5843b5a0ea11511e50039a7649619abfb4ba6d \
79+
--hash=sha256:93c36a08a6d158db44f2eb9ce3258e53f24a9a4a695325a689494f0fdbc71770 \
80+
--hash=sha256:95ce33057ba4d05df50b1f3cfefab22e351868a843b3b15a46c65836283670c9 \
81+
--hash=sha256:9849ce7267444c0a717c80c6900997de4f36e2815ce34ac560a3edb2d9a64cd2 \
82+
--hash=sha256:9d55ea7f7baf2aed61bf1872116cefc9d0c3693b45cae3916897ee27ef4b835e \
83+
--hash=sha256:a4f39b9bf6555fab9bfb536cf5fdd1c1c727e8d22312078702e9ff005354b37f \
84+
--hash=sha256:aec640bd94c4c85c0d11e2733bd13cbb10438fb004852996ec0efbc6cacdaf70 \
85+
--hash=sha256:aecbd7c5272c82e54d5b99d8435fd10915d1bc704b7df15e4d9ca8dc3902be61 \
86+
--hash=sha256:bda32ce212baa724e03c68771e5c69f39e584ea426bfe1a701cb01508ffc7035 \
87+
--hash=sha256:bdcf26c2dbc926b8a35ec8cbfad7eff1a8bd8239e12478caca83a1fc2c400dc2 \
88+
--hash=sha256:bdf40d2aaabd3913dec11840f0d0ebb1b93134f99af6a0a4fd88ffe924928ab4 \
89+
--hash=sha256:c205cac07d24a29840c163d6469f61069ce4b065518519216297fc2f261f8db9 \
90+
--hash=sha256:c3f5ae0309d9f888fd825c2e9d0241102fadaca81d888f26f845bc8c13c1e4ee \
91+
--hash=sha256:cd7c0bb22d4ff86d65ad61b5dd246812e8993fbc95b558553624c33e8b6903ea \
92+
--hash=sha256:d0f730a17cf4f343b2c7ad50cee3bd19e969e793d2be6ed911f43086460096e4 \
93+
--hash=sha256:da65e5fd3eea434ccb8984c3624bc234ddcc0d9f4c81864af611aaebcc08a50e \
94+
--hash=sha256:e12e29764a0e66a7a31e9b8bf1de5cc0423ea72979f45909acd4292de834ccd3

xla/tools/multihost_hlo_runner/BUILD

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
44
load("//xla:xla.default.bzl", "xla_cc_binary")
55
load("//xla/tests:build_defs.bzl", "xla_test")
66
load("//xla/tsl:tsl.bzl", "if_cuda_or_rocm", "if_google")
7+
load("//xla/tsl:tsl.default.bzl", "tsl_pybind_extension")
78
load("//xla/tsl/platform:build_config_root.bzl", "tf_gpu_tests_tags")
89
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
910

@@ -37,9 +38,7 @@ cc_library(
3738
testonly = True,
3839
srcs = ["hlo_runner_main.cc"],
3940
compatible_with = None,
40-
tags = [
41-
"no_mac",
42-
],
41+
tags = ["no_mac"],
4342
deps = [
4443
":create_client",
4544
":functional_hlo_runner",
@@ -71,10 +70,10 @@ cc_library(
7170
"@tsl//tsl/platform:statusor",
7271
] + if_cuda_or_rocm([
7372
"//xla/service:gpu_plugin",
74-
"//xla/backends/profiler/gpu:cupti_tracer",
7573
"//xla/backends/profiler/gpu:device_tracer",
7674
]) + if_cuda([
7775
"//xla/stream_executor:cuda_platform",
76+
"//xla/backends/profiler/gpu:cupti_tracer",
7877
] + if_google(
7978
[
8079
"//third_party/py/jax/jaxlib/cuda:cuda_gpu_kernels", # fixdeps: keep
@@ -259,6 +258,7 @@ xla_test(
259258
"//xla/service:computation_layout",
260259
"//xla/service:hlo_proto_cc",
261260
"//xla/tests:xla_test_backend_predicates",
261+
"//xla/tools/multihost_hlo_runner:profiler_interface",
262262
"//xla/tsl/lib/core:status_test_util",
263263
"//xla/tsl/platform:env",
264264
"//xla/tsl/platform:errors",
@@ -279,3 +279,58 @@ xla_test(
279279
"@tsl//tsl/platform:protobuf",
280280
],
281281
)
282+
283+
tsl_pybind_extension(
284+
name = "py_hlo_multihost_runner",
285+
srcs = ["python_hlo_runner.cc"],
286+
deps = [
287+
":create_client",
288+
":functional_hlo_runner",
289+
":hlo_input_output_format",
290+
":profiler_interface",
291+
"//xla:debug_options_flags",
292+
"//xla:status_macros",
293+
"//xla:xla_data_proto_cc",
294+
"//xla/ffi",
295+
"//xla/ffi:ffi_api",
296+
"//xla/ffi/api:c_api",
297+
"//xla/pjrt:pjrt_client",
298+
"//xla/pjrt:status_casters",
299+
"//xla/pjrt/distributed",
300+
"//xla/pjrt/distributed:client",
301+
"//xla/pjrt/distributed:key_value_store_interface",
302+
"//xla/pjrt/distributed:service",
303+
"//xla/pjrt/plugin/xla_gpu:xla_gpu_allocator_config",
304+
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
305+
"//xla/python:logging",
306+
"//xla/service:cpu_plugin",
307+
"//xla/service:custom_call_target_registry",
308+
"//xla/service:hlo_module_util",
309+
"//xla/tsl/platform:statusor",
310+
"//xla/tsl/util:command_line_flags",
311+
"@com_google_absl//absl/log",
312+
"@com_google_absl//absl/log:check",
313+
"@com_google_absl//absl/status",
314+
"@com_google_absl//absl/status:statusor",
315+
"@com_google_absl//absl/strings",
316+
"@com_google_absl//absl/time",
317+
"@nanobind",
318+
"@tsl//tsl/platform:errors",
319+
"@tsl//tsl/platform:logging",
320+
"@tsl//tsl/platform:platform_port",
321+
"@tsl//tsl/platform:status",
322+
"@tsl//tsl/platform:statusor",
323+
] + if_cuda_or_rocm([
324+
"//xla/service:gpu_plugin",
325+
"//xla/backends/profiler/gpu:device_tracer",
326+
]) + if_cuda([
327+
"//xla/stream_executor:cuda_platform",
328+
"//xla/backends/profiler/gpu:cupti_tracer",
329+
] + if_google(
330+
[
331+
"//third_party/py/jax/jaxlib/cuda:cuda_gpu_kernels", # fixdeps: keep
332+
],
333+
)) + if_rocm([
334+
"//xla/stream_executor:rocm_platform",
335+
]),
336+
)

xla/tools/multihost_hlo_runner/functional_hlo_runner.cc

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,15 @@ absl::StatusOr<PerDeviceLiteralVecType> RunInternal(
578578
futures.emplace();
579579
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> device_buffers;
580580
std::vector<std::vector<PjRtBuffer*>> argument_ptrs;
581+
582+
bool has_active_profiler_session = false;
581583
for (int repeat = 0; repeat < running_options.num_repeats; ++repeat) {
584+
const bool is_last_repeat = (repeat == running_options.num_repeats - 1);
585+
const bool profile_current_repeat =
586+
(running_options.profiler != nullptr) &&
587+
(repeat >= running_options.num_repeats -
588+
running_options.num_repeats_with_profiler);
589+
582590
VLOG(1) << "FunctionalHloRunner: ExecuteOnDevices started (repeat = "
583591
<< repeat << ").";
584592
{
@@ -592,29 +600,39 @@ absl::StatusOr<PerDeviceLiteralVecType> RunInternal(
592600
flatten_arguments));
593601
argument_ptrs = CreateArgumentPointersFromDeviceBuffers(device_buffers);
594602
}
595-
if (repeat == running_options.num_repeats - 1) {
603+
if (is_last_repeat) {
596604
execute_options.untuple_result = default_untuple_result;
597-
if (running_options.profiler != nullptr) {
598-
running_options.profiler->CreateSession();
599-
}
600605
}
601-
execute_options.launch_id = repeat + 1;
606+
execute_options.launch_id = repeat + 1 + running_options.base_run_id;
602607
if (running_options.execution_profiles != nullptr) {
603608
execute_options.execution_profile =
604609
&running_options.execution_profiles->emplace_back();
605610
execute_options.execution_profile->set_warmup_run_executed(repeat > 0);
606611
}
612+
613+
if (profile_current_repeat && !has_active_profiler_session) {
614+
running_options.profiler->CreateSession();
615+
has_active_profiler_session = true;
616+
}
607617
futures->clear();
608618
TF_ASSIGN_OR_RETURN(
609619
output_buffers,
610620
executable->Execute(argument_ptrs, execute_options, futures));
611621
for (auto& future : *futures) {
612622
TF_RETURN_IF_ERROR(future.Await());
613623
}
624+
625+
const bool upload_active_profiler_session =
626+
running_options.recreate_profiler_session_between_repeats ||
627+
is_last_repeat;
628+
if (has_active_profiler_session && upload_active_profiler_session) {
629+
running_options.profiler->UploadSession();
630+
has_active_profiler_session = false;
631+
}
614632
}
615633
VLOG(1) << "FunctionalHloRunner: ExecuteOnDevices succeeded (repeat = "
616634
<< repeat << ")";
617-
if (repeat < running_options.num_repeats - 1) {
635+
if (!is_last_repeat) {
618636
switch (parameter_type) {
619637
case ParameterType::kOneTupleOfArrays:
620638
argument_ptrs = CreateArgumentPointersBasedOnAliasing(
@@ -638,9 +656,6 @@ absl::StatusOr<PerDeviceLiteralVecType> RunInternal(
638656
FetchAndLogOutput(client, output_buffers,
639657
running_options.module_output_mode,
640658
running_options.log_input_output()));
641-
if (running_options.profiler != nullptr) {
642-
running_options.profiler->UploadSession();
643-
}
644659
return results;
645660
}
646661

xla/tools/multihost_hlo_runner/functional_hlo_runner.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,13 @@ struct RunningOptions {
249249
ModuleOutputMode module_output_mode = ModuleOutputMode::kReturnOutputs;
250250
// Repeatedly execute the HLO for this many times.
251251
size_t num_repeats = 1;
252+
// The last `num_repeats_with_profiler` repeats out of `num_repeats` will be
253+
// profiled. Default is 1, i.e., the last repeat will be profiled.
254+
size_t num_repeats_with_profiler = 1;
255+
// If true, we recreate the profiler session between repeats when profiling
256+
// more than one repeat.
257+
bool recreate_profiler_session_between_repeats = false;
258+
size_t base_run_id = 0;
252259
// If true, we recreate the buffers between repeats to reset of effect of
253260
// buffer donation.
254261
bool recreate_buffers_between_repeats = false;

xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ limitations under the License.
4343
#include "xla/status_macros.h"
4444
#include "xla/tools/multihost_hlo_runner/create_client.h"
4545
#include "xla/tools/multihost_hlo_runner/hlo_input_output_format.h"
46+
#include "xla/tools/multihost_hlo_runner/profiler_interface.h"
4647
#include "xla/tsl/lib/core/status_test_util.h"
4748
#include "xla/tsl/platform/env.h"
4849
#include "xla/tsl/platform/errors.h"
@@ -868,6 +869,56 @@ TEST_F(FunctionalHloRunnerTest, DumpsUnoptimizedHLOInUnoptimizedSnapshot) {
868869
EXPECT_FALSE(snapshot.hlo_module().has_schedule());
869870
}
870871

872+
class MockProfiler : public ProfilerInterface {
873+
public:
874+
MOCK_METHOD(void, CreateSession, (), (override));
875+
MOCK_METHOD(void, UploadSession, (), (override));
876+
};
877+
878+
TEST_F(FunctionalHloRunnerTest, ProfileMultipleRepeatsSingleSession) {
879+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
880+
GetPjRtClient());
881+
xla::DebugOptions debug_options;
882+
FunctionalHloRunner::PreprocessingOptions preproc_options;
883+
CompileOptions compile_options;
884+
885+
FunctionalHloRunner::RunningOptions running_options;
886+
MockProfiler mock_profiler;
887+
running_options.profiler = &mock_profiler;
888+
running_options.num_repeats = 5;
889+
running_options.num_repeats_with_profiler = 3;
890+
running_options.recreate_profiler_session_between_repeats = false;
891+
892+
EXPECT_CALL(mock_profiler, CreateSession()).Times(1);
893+
EXPECT_CALL(mock_profiler, UploadSession()).Times(1);
894+
895+
TF_EXPECT_OK(FunctionalHloRunner::LoadAndRun(
896+
*client, debug_options, preproc_options, compile_options, running_options,
897+
{GetHloPath("single_device.hlo")}, InputFormat::kText));
898+
}
899+
900+
TEST_F(FunctionalHloRunnerTest, ProfileMultipleRepeatsSessionPerRepeat) {
901+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
902+
GetPjRtClient());
903+
xla::DebugOptions debug_options;
904+
FunctionalHloRunner::PreprocessingOptions preproc_options;
905+
CompileOptions compile_options;
906+
907+
FunctionalHloRunner::RunningOptions running_options;
908+
MockProfiler mock_profiler;
909+
running_options.profiler = &mock_profiler;
910+
running_options.num_repeats = 5;
911+
running_options.num_repeats_with_profiler = 3;
912+
running_options.recreate_profiler_session_between_repeats = true;
913+
914+
EXPECT_CALL(mock_profiler, CreateSession()).Times(3);
915+
EXPECT_CALL(mock_profiler, UploadSession()).Times(3);
916+
917+
TF_EXPECT_OK(FunctionalHloRunner::LoadAndRun(
918+
*client, debug_options, preproc_options, compile_options, running_options,
919+
{GetHloPath("single_device.hlo")}, InputFormat::kText));
920+
}
921+
871922
} // namespace
872923
} // namespace xla
873924

0 commit comments

Comments
 (0)