Skip to content

Commit 1000ed5

Browse files
sergachev authored and Google-ML-Automation committed
PR #19571: PJRT: assign process index and count for compilation using device assignment.
Imported from GitHub PR #19571.

Only a subset of processes may be participating in the compilation of a module.

Copybara import of the project:

-- 15250fc by Ilia Sergachev <[email protected]>: PJRT: assign process index and count for compilation using device assignment. Only a subset of processes may be participating in the compilation of a module.

-- 8620919 by Ilia Sergachev <[email protected]>: fix functional_hlo_runner_test

Merging this change closes #19571.

COPYBARA_INTEGRATE_REVIEW=#19571 from openxla:pjrt_fix_process_index_count 8620919
PiperOrigin-RevId: 702231769
1 parent a76d5b4 commit 1000ed5

File tree

6 files changed: +48 -16 lines changed

xla/pjrt/gpu/se_gpu_pjrt_client_test.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,20 +1840,27 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
18401840
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
18411841
}
18421842

1843-
class ShardedAutotuningTest : public ::testing::TestWithParam<bool> {
1843+
class ShardedAutotuningTest
1844+
: public ::testing::TestWithParam<std::tuple<bool, int>> {
18441845
public:
18451846
static constexpr int kNumNodes = 2;
18461847
};
18471848

18481849
static const char* test_binary_name;
18491850

18501851
TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
1852+
bool use_xla_computation;
1853+
int num_active_nodes;
1854+
std::tie(use_xla_computation, num_active_nodes) = GetParam();
1855+
18511856
tsl::SubProcess child[ShardedAutotuningTest::kNumNodes];
18521857
for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) {
18531858
std::vector<std::string> argv;
18541859
argv.push_back(test_binary_name);
18551860
argv.push_back(absl::StrFormat("--node_id=%d", node_id));
1856-
argv.push_back(absl::StrFormat("--use_xla_computation=%d", GetParam()));
1861+
argv.push_back(
1862+
absl::StrFormat("--use_xla_computation=%d", use_xla_computation));
1863+
argv.push_back(absl::StrFormat("--num_active_nodes=%d", num_active_nodes));
18571864
child[node_id].SetProgram(test_binary_name, argv);
18581865
child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
18591866
child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
@@ -1876,6 +1883,7 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
18761883
}
18771884

18781885
absl::Status ShardedAutotuningWorksTestBody(const int node_id,
1886+
const int num_active_nodes,
18791887
bool use_xla_computation) {
18801888
std::unique_ptr<xla::DistributedRuntimeService> service;
18811889
if (node_id == 0) {
@@ -1911,6 +1919,11 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
19111919
TF_RET_CHECK(client->addressable_device_count() == 1);
19121920
TF_RET_CHECK(client->device_count() == ShardedAutotuningTest::kNumNodes);
19131921

1922+
if (node_id >= num_active_nodes) {
1923+
// Inactive nodes connect to the coordination service but don't compile.
1924+
return absl::OkStatus();
1925+
}
1926+
19141927
CompileOptions compile_options;
19151928
DebugOptions* debug_options =
19161929
compile_options.executable_build_options.mutable_debug_options();
@@ -1951,8 +1964,11 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
19511964
return absl::OkStatus();
19521965
}
19531966

1954-
INSTANTIATE_TEST_SUITE_P(ShardedAutotuningTest, ShardedAutotuningTest,
1955-
::testing::Values(false, true));
1967+
INSTANTIATE_TEST_SUITE_P(
1968+
ShardedAutotuningTest, ShardedAutotuningTest,
1969+
::testing::Combine(::testing::Bool(),
1970+
::testing::Range(1,
1971+
ShardedAutotuningTest::kNumNodes + 1)));
19561972

19571973
} // namespace
19581974
} // namespace xla
@@ -1961,10 +1977,13 @@ int main(int argc, char* argv[]) {
19611977
// Save name of binary so that it may invoke itself.
19621978
xla::test_binary_name = argv[0];
19631979
int node_id = -1;
1980+
int num_active_nodes = -1;
19641981
bool use_xla_computation = false;
19651982
std::vector<tsl::Flag> flag_list = {
19661983
tsl::Flag("node_id", &node_id,
19671984
"Node ID for ShardedAutotuningWorks test."),
1985+
tsl::Flag("num_active_nodes", &num_active_nodes,
1986+
"Test parameter for ShardedAutotuningWorks."),
19681987
tsl::Flag("use_xla_computation", &use_xla_computation,
19691988
"Test parameter for ShardedAutotuningWorks."),
19701989
};
@@ -1973,7 +1992,8 @@ int main(int argc, char* argv[]) {
19731992
tsl::Flags::Parse(&argc, argv, flag_list);
19741993
testing::InitGoogleTest(&argc, argv);
19751994
if (node_id >= 0) {
1976-
return xla::ShardedAutotuningWorksTestBody(node_id, use_xla_computation)
1995+
return xla::ShardedAutotuningWorksTestBody(node_id, num_active_nodes,
1996+
use_xla_computation)
19771997
.raw_code();
19781998
}
19791999
return RUN_ALL_TESTS();

xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3481,17 +3481,23 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
34813481
if (device_assignment != nullptr) {
34823482
addressable_device_logical_ids.reserve(num_replicas * num_partitions);
34833483
addressable_devices.reserve(num_replicas * num_partitions);
3484+
absl::flat_hash_set<int> all_process_indices;
3485+
std::optional<int> this_process_index;
34843486
for (int replica = 0; replica < num_replicas; ++replica) {
34853487
for (int partition = 0; partition < num_partitions; ++partition) {
34863488
int64_t device_id = (*device_assignment)(replica, partition);
34873489
PjRtGlobalDeviceId global_device_id(device_id);
34883490

34893491
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
34903492
LookupDevice(global_device_id));
3493+
all_process_indices.insert(device->process_index());
34913494
if (device->process_index() != process_index()) {
34923495
VLOG(3) << "Non-local device: " << device_id;
34933496
continue;
34943497
}
3498+
if (!this_process_index.has_value()) {
3499+
this_process_index = all_process_indices.size() - 1;
3500+
}
34953501
PjRtLoadedExecutable::LogicalDeviceIds logica_device_ids;
34963502
logica_device_ids.replica = replica;
34973503
logica_device_ids.partition = partition;
@@ -3509,6 +3515,9 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
35093515
build_options.set_device_ordinal(
35103516
addressable_devices.front()->local_hardware_id().value());
35113517
}
3518+
3519+
build_options.set_process_index(*this_process_index);
3520+
build_options.set_process_count(all_process_indices.size());
35123521
}
35133522
return extras;
35143523
}
@@ -3525,11 +3534,6 @@ PjRtStreamExecutorClient::CompileInternal(
35253534
!options.executable_build_options.key_value_store()) {
35263535
options.executable_build_options.set_key_value_store(*key_value_store());
35273536
}
3528-
options.executable_build_options.set_process_index(process_index());
3529-
TF_RET_CHECK(device_count() % addressable_device_count() == 0)
3530-
<< "Each process is expected to have the same number of devices";
3531-
options.executable_build_options.set_process_count(
3532-
device_count() / addressable_device_count());
35333537
auto input_options = options;
35343538

35353539
TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides());

xla/tools/multihost_hlo_runner/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ xla_test(
200200
":create_client",
201201
":functional_hlo_runner",
202202
"//xla:debug_options_flags",
203+
"//xla:status_macros",
203204
"//xla:xla_proto_cc",
204205
"//xla/hlo/testlib:filecheck",
205206
"//xla/pjrt:pjrt_client",

xla/tools/multihost_hlo_runner/functional_hlo_runner.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,8 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
544544
const PreprocessingOptions& preproc_options,
545545
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
546546
InputFormat input_format, int task_id, int num_nodes,
547-
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
547+
std::shared_ptr<xla::KeyValueStoreInterface> kv_store,
548+
bool use_gpu_count_workaround) {
548549
TF_ASSIGN_OR_RETURN(
549550
CompileOptions compile_options,
550551
FunctionalHloRunner::CreateCompileOptions(client, raw_compile_options,
@@ -554,7 +555,8 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
554555
int num_partitions =
555556
compile_options.executable_build_options.num_partitions();
556557
int needed_devices = num_replicas * num_partitions;
557-
if (client.addressable_device_count() < needed_devices) {
558+
if (client.addressable_device_count() < needed_devices &&
559+
use_gpu_count_workaround) {
558560
LOG(INFO) << "Applying a workaround to allow compiling multi-device HLOs "
559561
"on machines with fewer devices.";
560562
DeviceAssignment assignment(num_replicas, num_partitions);

xla/tools/multihost_hlo_runner/functional_hlo_runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ class FunctionalHloRunner {
270270
const PreprocessingOptions& preproc_options,
271271
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
272272
InputFormat input_format, int task_id = 0, int num_nodes = 1,
273-
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr);
273+
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr,
274+
bool use_gpu_count_workaround = true);
274275

275276
// Compiles and runs the given HLO module with the given arguments for each
276277
// device. The given arguments is a map from device ID to a list of arguments.

xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "xla/hlo/testlib/filecheck.h"
3131
#include "xla/pjrt/pjrt_client.h"
3232
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
33+
#include "xla/status_macros.h"
3334
#include "xla/tools/multihost_hlo_runner/create_client.h"
3435
#include "xla/tsl/lib/core/status_test_util.h"
3536
#include "xla/tsl/util/command_line_flags.h"
@@ -296,7 +297,9 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
296297
PjRtEnvironment env,
297298
xla::GetPjRtEnvironmentForGpu("127.0.0.1:12345", gpu_options,
298299
/*init_timeout=*/absl::Seconds(120)));
299-
CHECK(env.kv_store != nullptr);
300+
TF_RET_CHECK(env.kv_store != nullptr);
301+
TF_RET_CHECK(env.client->device_count() == kNumNodes);
302+
TF_RET_CHECK(env.client->addressable_device_count() == 1);
300303
// Make HLO module IDs of multiple_gemm_fusions.hlo differ: the autotuner
301304
// should not rely on them.
302305
if (node_id == 0) {
@@ -310,9 +313,10 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
310313
TF_RETURN_IF_ERROR(FunctionalHloRunner::LoadAndCompile(
311314
*env.client, GetDebugOptionsFromFlags(),
312315
FunctionalHloRunner::PreprocessingOptions{},
313-
FunctionalHloRunner::RawCompileOptions{},
316+
FunctionalHloRunner::RawCompileOptions{.num_replicas = kNumNodes},
314317
GetHloPath(absl::StrFormat("multiple_gemm_fusions_%d.hlo", node_id + 1)),
315-
InputFormat::kText));
318+
InputFormat::kText, node_id, kNumNodes, /*kv_store=*/nullptr,
319+
/*use_gpu_count_workaround=*/false));
316320
if (node_id == 0) {
317321
TF_ASSIGN_OR_RETURN(
318322
std::string results0,

0 commit comments

Comments
 (0)