Skip to content

Commit 2e19703

Browse files
olegshyshkovGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Remove unused ExecuteReplicated that require OpaqueExecutable.
PiperOrigin-RevId: 839297753
1 parent f4a264c commit 2e19703

File tree

5 files changed

+37
-139
lines changed

5 files changed

+37
-139
lines changed

xla/hlo/testlib/hlo_hardware_independent_test_base.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ limitations under the License.
3737
#include "xla/debug_options_flags.h"
3838
#include "xla/hlo/ir/hlo_instruction.h"
3939
#include "xla/hlo/ir/hlo_module.h"
40-
#include "xla/hlo/ir/hlo_module_group.h"
4140
#include "xla/hlo/ir/hlo_opcode.h"
41+
#include "xla/hlo/ir/hlo_print_options.h"
4242
#include "xla/hlo/parser/hlo_parser.h"
4343
#include "xla/hlo/pass/hlo_pass_interface.h"
4444
#include "xla/hlo/testlib/filecheck.h"

xla/tests/collective_ops_e2e_test.cc

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2792,27 +2792,23 @@ TEST_F(CollectiveOpsTestE2E, OptimizedSubByteAllGatherOnDim0OutputIsCorrect) {
27922792
e {
27932793
a = s4[2,4]{1,0:E(4)} constant({{0,1,2,3},{4,5,5,4}})
27942794
b = s4[4,4]{1,0:E(4)} all-gather(a), dimensions={0}
2795-
})"));
2795+
})",
2796+
kNumReplicas));
27962797

2797-
TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
2798-
std::move(unoptimized_module),
2799-
/*run_hlo_passes=*/true));
2800-
2801-
TF_ASSERT_OK_AND_ASSIGN(const HloModule* const module,
2802-
hlo_runner_->HloModuleFromWrapped(executable.get()));
2798+
TF_ASSERT_OK_AND_ASSIGN(ExecutionResult execution_result,
2799+
ExecuteReplicated(std::move(unoptimized_module)));
28032800

2801+
const HloModule* module = execution_result.optimized_module;
28042802
EXPECT_THAT(module->entry_computation()->root_instruction(),
28052803
GmockMatch(m::Bitcast(m::AllGatherDone().WithShape(S8, {4, 2}))));
28062804

2807-
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> result,
2808-
ExecuteReplicated(executable.get(), kNumReplicas));
2809-
28102805
const Literal expected_result =
28112806
LiteralUtil::CreateR2<s4>({{s4(0), s4(1), s4(2), s4(3)},
28122807
{s4(4), s4(5), s4(5), s4(4)},
28132808
{s4(0), s4(1), s4(2), s4(3)},
28142809
{s4(4), s4(5), s4(5), s4(4)}});
28152810

2811+
const std::vector<Literal>& result = execution_result.results;
28162812
ASSERT_EQ(result.size(), kNumReplicas);
28172813
for (int i = 0; i < kNumReplicas; ++i) {
28182814
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result[i]))
@@ -2833,30 +2829,26 @@ TEST_F(CollectiveOpsTestE2E, OptimizedSubByteAllGatherOnDim1OutputIsCorrect) {
28332829
e {
28342830
a = s4[4,2]{1,0:E(4)} constant({{0,1},{2,3},{4,5},{5,4}})
28352831
b = s4[4,4]{1,0:E(4)} all-gather(a), dimensions={1}
2836-
})"));
2832+
})",
2833+
kNumReplicas));
28372834

2838-
TF_ASSERT_OK_AND_ASSIGN(auto executable, hlo_runner_->CreateExecutable(
2839-
std::move(unoptimized_module),
2840-
/*run_hlo_passes=*/true));
2841-
2842-
TF_ASSERT_OK_AND_ASSIGN(const HloModule* const module,
2843-
hlo_runner_->HloModuleFromWrapped(executable.get()));
2835+
TF_ASSERT_OK_AND_ASSIGN(ExecutionResult execution_result,
2836+
ExecuteReplicated(std::move(unoptimized_module)));
28442837

2838+
const HloModule* module = execution_result.optimized_module;
28452839
const HloInstruction* root = module->entry_computation()->root_instruction();
28462840
EXPECT_THAT(root, GmockMatch(m::Fusion(
28472841
m::Bitcast(m::AllGatherDone().WithShape(S8, {2, 4})))));
28482842
EXPECT_THAT(root->fused_expression_root(),
28492843
GmockMatch(m::Transpose(m::Parameter())));
28502844

2851-
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> result,
2852-
ExecuteReplicated(executable.get(), kNumReplicas));
2853-
28542845
const Literal expected_result =
28552846
LiteralUtil::CreateR2<s4>({{s4(0), s4(1), s4(0), s4(1)},
28562847
{s4(2), s4(3), s4(2), s4(3)},
28572848
{s4(4), s4(5), s4(4), s4(5)},
28582849
{s4(5), s4(4), s4(5), s4(4)}});
28592850

2851+
const std::vector<Literal>& result = execution_result.results;
28602852
ASSERT_EQ(result.size(), kNumReplicas);
28612853
for (int i = 0; i < kNumReplicas; ++i) {
28622854
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result[i]))

xla/tests/collective_ops_e2e_test_base.cc

Lines changed: 10 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -116,67 +116,6 @@ CollectiveOpsE2ETestBase::CollectiveOpsE2ETestBase() {
116116
reference_platform, /*intra_op_parallelism_threads=*/0);
117117
}
118118

119-
absl::StatusOr<std::vector<Literal>>
120-
CollectiveOpsE2ETestBase::ExecuteReplicated(
121-
absl::AnyInvocable<OpaqueExecutable*(int64_t)> executable_provider,
122-
absl::AnyInvocable<int64_t(int64_t)> argument_count_provider,
123-
absl::AnyInvocable<const Literal*(int64_t, int64_t)> argument_provider,
124-
const int64_t num_replicas, const bool run_hlo_passes,
125-
DeviceAssignment* const device_assignment) {
126-
// TODO(b/441865120): Use designated initializers this once XLA moves to
127-
// C++20.
128-
HloRunnerInterface::ReplicatedExecuteOptions options;
129-
options.num_replicas = num_replicas;
130-
options.run_hlo_passes = run_hlo_passes;
131-
options.use_threads = true;
132-
133-
return hlo_runner_->ExecuteReplicated(
134-
std::move(executable_provider), std::move(argument_count_provider),
135-
std::move(argument_provider), std::move(options), device_assignment);
136-
}
137-
138-
absl::StatusOr<std::vector<Literal>>
139-
CollectiveOpsE2ETestBase::ExecuteReplicated(
140-
std::unique_ptr<HloModule> module,
141-
const absl::Span<const Literal* const> arguments,
142-
const int64_t num_replicas, DeviceAssignment* const device_assignment,
143-
const bool run_hlo_passes, const bool use_threads) {
144-
// TODO(b/441865120): Use designated initializers this once XLA moves to
145-
// C++20.
146-
HloRunnerInterface::ReplicatedExecuteOptions options;
147-
options.num_replicas = num_replicas;
148-
options.arguments = {arguments.begin(), arguments.end()};
149-
options.run_hlo_passes = run_hlo_passes;
150-
options.use_threads = use_threads;
151-
152-
return hlo_runner_->ExecuteReplicated(std::move(module), std::move(options),
153-
device_assignment);
154-
}
155-
156-
absl::StatusOr<std::vector<Literal>>
157-
CollectiveOpsE2ETestBase::ExecuteReplicated(
158-
std::unique_ptr<HloModule> module,
159-
const std::vector<std::vector<Literal*>> arguments,
160-
DeviceAssignment* const device_assignment, const int64_t num_replicas,
161-
const bool run_hlo_passes) {
162-
CHECK(num_replicas > 0 && "expect at least one replica");
163-
CHECK(num_replicas == arguments.size() &&
164-
"expect arguments for each replica");
165-
int64_t argument_count = arguments.front().size();
166-
TF_ASSIGN_OR_RETURN(
167-
const std::unique_ptr<OpaqueExecutable> executable,
168-
hlo_runner_->CreateExecutable(std::move(module), run_hlo_passes));
169-
return ExecuteReplicated(
170-
/*executable_provider=*/[&](int64_t) { return executable.get(); },
171-
/*argument_count_provider=*/[&](int64_t) { return argument_count; },
172-
/*argument_provider=*/
173-
[&](int64_t replica_idx, int64_t argument_idx) -> const Literal* {
174-
return arguments[replica_idx][argument_idx];
175-
},
176-
num_replicas, /*run_hlo_passes=*/run_hlo_passes,
177-
/*device_assignment=*/device_assignment);
178-
}
179-
180119
absl::StatusOr<CollectiveOpsE2ETestBase::ExecutionResult>
181120
CollectiveOpsE2ETestBase::ExecuteReplicated(std::unique_ptr<HloModule> module) {
182121
return ExecuteReplicated(std::move(module),
@@ -224,9 +163,16 @@ CollectiveOpsE2ETestBase::ExecuteReplicated(
224163
execution_result.optimized_module,
225164
hlo_runner_->HloModuleFromWrapped(execution_result.executable.get()));
226165

166+
// TODO(b/441865120): Use designated initializers this once XLA moves to
167+
// C++20.
168+
HloRunnerInterface::ReplicatedExecuteOptions options;
169+
options.num_replicas = num_devices;
170+
options.run_hlo_passes = run_hlo_passes;
171+
options.use_threads = true;
172+
227173
TF_ASSIGN_OR_RETURN(
228174
execution_result.results,
229-
ExecuteReplicated(
175+
hlo_runner_->ExecuteReplicated(
230176
/*executable_provider=*/
231177
[&](int64_t) { return execution_result.executable.get(); },
232178
/*argument_count_provider=*/
@@ -235,21 +181,10 @@ CollectiveOpsE2ETestBase::ExecuteReplicated(
235181
[&](int64_t replica_idx, int64_t argument_idx) -> const Literal* {
236182
return arguments[replica_idx][argument_idx];
237183
},
238-
/*num_replicas=*/num_devices,
239-
/*run_hlo_passes=*/run_hlo_passes,
184+
std::move(options),
240185
/*device_assignment=*/&device_assignment));
241-
return execution_result;
242-
}
243186

244-
absl::StatusOr<std::vector<Literal>>
245-
CollectiveOpsE2ETestBase::ExecuteReplicated(OpaqueExecutable* executable,
246-
int64_t num_replicas) {
247-
DeviceAssignment device_assignment = MakeDeviceAssignment(num_replicas);
248-
return ExecuteReplicated(
249-
/*executable_provider*/ [&](int64_t) { return executable; },
250-
/*argument_count_provider*/ [](int64_t) { return 0; },
251-
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
252-
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
187+
return execution_result;
253188
}
254189

255190
DebugOptions CollectiveOpsWithFlagsBase::GetDebugOptionsForTest() const {

xla/tests/collective_ops_e2e_test_base.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,6 @@ class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
5454
const HloModule* optimized_module;
5555
};
5656

57-
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
58-
absl::AnyInvocable<OpaqueExecutable*(int64_t)> executable_provider,
59-
absl::AnyInvocable<int64_t(int64_t)> argument_count_provider,
60-
absl::AnyInvocable<const Literal*(int64_t, int64_t)> argument_provider,
61-
int64_t num_replicas, bool run_hlo_passes,
62-
DeviceAssignment* device_assignment);
63-
64-
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
65-
std::unique_ptr<HloModule> module,
66-
absl::Span<const Literal* const> arguments, int64_t num_replicas,
67-
DeviceAssignment* vice_assignment, bool run_hlo_passes, bool use_threads);
68-
69-
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
70-
std::unique_ptr<HloModule> module,
71-
std::vector<std::vector<Literal*>> arguments,
72-
DeviceAssignment* device_assignment, int64_t num_replicas,
73-
bool run_hlo_passes);
74-
75-
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
76-
OpaqueExecutable* executable, int64_t num_replicas);
77-
7857
absl::StatusOr<ExecutionResult> ExecuteReplicated(
7958
std::unique_ptr<HloModule> module);
8059

xla/tests/collective_ops_sharded_unsharded_e2e_test.cc

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
6161
<< " available)";
6262
}
6363

64-
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> ref_results,
64+
TF_ASSERT_OK_AND_ASSIGN(ExecutionResult ref_execution_result,
6565
ExecuteUnsharded(hlo_text));
66+
const std::vector<Literal>& ref_results = ref_execution_result.results;
6667
ASSERT_EQ(ref_results.size(), 1);
6768

6869
TF_ASSERT_OK_AND_ASSIGN(
69-
std::vector<Literal> results,
70+
ExecutionResult execution_result,
7071
ExecuteSharded(hlo_text, num_partitions, enable_enzyme_comms_opt));
72+
const std::vector<Literal>& results = execution_result.results;
7173
ASSERT_EQ(results.size(), num_partitions);
7274

7375
ErrorSpec error_spec{1e-4, 1e-4};
@@ -77,7 +79,7 @@ class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
7779

7880
private:
7981
// Execute the unsharded case.
80-
absl::StatusOr<std::vector<Literal>> ExecuteUnsharded(
82+
absl::StatusOr<ExecutionResult> ExecuteUnsharded(
8183
const std::string& hlo_text) {
8284
// Create the unsharded reference case by removing the sharding metadata
8385
// from the HLO string.
@@ -90,9 +92,12 @@ class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
9092
DebugOptions ref_opts = GetDebugOptionsForTest();
9193
ref_opts.set_xla_gpu_enable_triton_gemm(false);
9294
ref_config.set_debug_options(ref_opts);
93-
ref_config.set_num_partitions(1);
9495
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> ref_module,
9596
ParseAndReturnVerifiedModule(hlo_text_ref, ref_config));
97+
98+
ref_module->mutable_config().set_replica_count(1);
99+
ref_module->mutable_config().set_num_partitions(1);
100+
96101
const int64_t num_params =
97102
ref_module->entry_computation()->num_parameters();
98103

@@ -102,17 +107,11 @@ class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
102107
ref_fake_ptrs[i] = &fake_args[i];
103108
}
104109

105-
DeviceAssignment ref_assn(/*replica_count=*/1,
106-
/*computation_count=*/1);
107-
ref_assn(0, 0) = 0;
108-
return ExecuteReplicated(std::move(ref_module), ref_fake_ptrs,
109-
/*num_replicas=*/1, &ref_assn,
110-
/*run_hlo_passes=*/true,
111-
/*use_threads=*/true);
110+
return ExecuteReplicated(std::move(ref_module), ref_fake_ptrs);
112111
}
113112

114113
// Execute the sharded case.
115-
absl::StatusOr<std::vector<Literal>> ExecuteSharded(
114+
absl::StatusOr<ExecutionResult> ExecuteSharded(
116115
const std::string& hlo_text, int64_t num_partitions,
117116
bool enable_enzyme_comms_opt = false) {
118117
HloModuleConfig config = GetModuleConfigForTest();
@@ -182,21 +181,14 @@ class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
182181
}
183182
}
184183

185-
DeviceAssignment assn(/*replica_count=*/1,
186-
/*computation_count=*/num_partitions);
187-
for (int64_t i = 0; i < num_partitions; ++i) {
188-
assn(0, i) = i;
189-
}
190-
return ExecuteReplicated(std::move(module), fake_ptrs, &assn,
191-
num_partitions,
192-
/*run_hlo_passes=*/true);
184+
return ExecuteReplicated(std::move(module), fake_ptrs);
193185
}
194186

195187
// Slice the unsharded reference results and compare to the sharded case.
196188
void CompareShardedUnsharded(const std::string& hlo_text,
197189
int64_t num_partitions,
198-
std::vector<Literal>& ref_results,
199-
std::vector<Literal>& results,
190+
const std::vector<Literal>& ref_results,
191+
const std::vector<Literal>& results,
200192
ErrorSpec& error_spec) {
201193
HloModuleConfig config = GetModuleConfigForTest();
202194
DebugOptions opts = GetDebugOptionsForTest();

0 commit comments

Comments
 (0)