Skip to content

Commit f4a264c

Browse files
Move CollectiveMetadataTest to a separate file.
PiperOrigin-RevId: 839275803
1 parent b170c49 commit f4a264c

File tree

3 files changed

+204
-145
lines changed

3 files changed

+204
-145
lines changed

xla/tests/BUILD

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2960,6 +2960,33 @@ xla_test(
29602960
],
29612961
)
29622962

2963+
# GPU-only test target for the CollectiveMetadata custom-call tests that live
# in collective_metadata_test.cc.
xla_test(
    name = "collective_metadata_test",
    srcs = ["collective_metadata_test.cc"],
    backend_tags = {
        "gpu": ["multi_gpu"],
        "nvgpu_any": [
            "broken",
            "no_oss",
        ],
    },
    backends = ["gpu"],
    deps = [
        ":collective_ops_e2e_test_base",
        ":xla_internal_test_main",
        "//xla:literal",
        "//xla:literal_util",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)
2989+
29632990
xla_test(
29642991
name = "collective_ops_sharded_unsharded_e2e_test",
29652992
srcs = ["collective_ops_sharded_unsharded_e2e_test.cc"],
xla/tests/collective_metadata_test.cc

Lines changed: 177 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,177 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include <array>
16+
#include <cstdint>
17+
#include <memory>
18+
#include <utility>
19+
#include <vector>
20+
21+
#include "absl/strings/string_view.h"
22+
#include "absl/types/span.h"
23+
#include "xla/literal.h"
24+
#include "xla/literal_util.h"
25+
#include "xla/tests/collective_ops_e2e_test_base.h"
26+
#include "xla/tsl/platform/statusor.h"
27+
#include "xla/tsl/platform/test.h"
28+
29+
namespace xla {
30+
namespace {
31+
32+
class CollectiveMetadataTest : public CollectiveOpsE2ETestBase {
33+
protected:
34+
void SetUp() override {
35+
CollectiveOpsE2ETestBase::SetUp();
36+
if (!IsHopperAndHigher()) {
37+
GTEST_SKIP() << "Test requires Hopper or newer architecture since it's "
38+
"using a multicast.";
39+
}
40+
}
41+
};
42+
43+
// Runs the "CollectiveMetadata" custom call on two replicas and checks the
// u64[9] metadata buffer each replica produces:
//   [0]     rank of the replica,
//   [1]     pointer to peers (only checked to be non-zero),
//   [2]     pointer to multimem metadata (only checked to be non-zero),
//   [3..8]  param_to_peers entries (non-zero and equal on both replicas).
TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadata) {
  constexpr int kNumReplicas = 2;
  constexpr int kNumElements = 9;
  const absl::string_view kHloText = R"(
HloModule test, replica_count=2

ENTRY test_computation {
param_0 = f32[4] parameter(0)
param_1 = f32[4] parameter(1)
copy_1 = f32[4]{0:S(1)} copy(param_1)

const_0 = f32[1] constant({10})

result_tuple = (f32[4], f32[4]{0:S(1)}, f32[1], u64[9]) custom-call(param_0, copy_1, const_0), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}
ROOT get_tuple_element = u64[9] get-tuple-element(result_tuple), index=3
})";

  // NOTE(review): the sibling replica-group test skips via GTEST_SKIP when
  // there are too few devices, whereas this one fails. Confirm the difference
  // is intentional.
  ASSERT_GE(hlo_runner_->device_count(), kNumReplicas)
      << "Test requires at least " << kNumReplicas << " devices ("
      << hlo_runner_->device_count() << " available)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, kNumReplicas));

  // Both parameters receive the same four-element payload.
  Literal lhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  Literal rhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  TF_ASSERT_OK_AND_ASSIGN(
      ExecutionResult run,
      ExecuteReplicated(std::move(hlo_module),
                        /*arguments=*/std::vector<Literal*>{&lhs, &rhs},
                        /*run_hlo_passes=*/false));

  const std::vector<Literal>& outputs = run.results;
  ASSERT_EQ(outputs.size(), kNumReplicas);

  absl::Span<const uint64_t> replica0 = outputs[0].data<uint64_t>();
  absl::Span<const uint64_t> replica1 = outputs[1].data<uint64_t>();
  ASSERT_EQ(replica0.size(), kNumElements);
  ASSERT_EQ(replica1.size(), kNumElements);

  // Slot 0: the rank of each replica.
  EXPECT_EQ(replica0[0], 0);
  EXPECT_EQ(replica1[0], 1);

  // Slot 1 (pointer to peers) and slot 2 (pointer to multimem metadata) must
  // both be populated on every replica.
  for (int slot = 1; slot <= 2; ++slot) {
    EXPECT_NE(replica0[slot], 0) << "slot " << slot;
    EXPECT_NE(replica1[slot], 0) << "slot " << slot;
  }

  // Remaining slots: the param_to_peers structure, which must be populated
  // and identical across the two replicas.
  for (int slot = 3; slot < kNumElements; ++slot) {
    EXPECT_NE(replica0[slot], 0) << "slot " << slot;
    EXPECT_EQ(replica1[slot], replica0[slot]) << "slot " << slot;
  }
}
101+
102+
// Runs the "CollectiveMetadata" custom call on four replicas that the backend
// config splits into the replica groups {0,1} and {2,3}, then checks the
// u64[7] metadata buffer of every replica:
//   [0]     rank of the replica within its group,
//   [1]     pointer to peers (only checked to be non-zero),
//   [2]     pointer to multimem metadata (only checked to be non-zero),
//   [3..6]  param_to_peers entries (non-zero and equal within a group).
TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadataWithReplicaGroup) {
  constexpr int kNumReplicas = 4;
  constexpr int kGroupSize = 2;
  constexpr int kNumElements = 7;
  const absl::string_view kHloText = R"(
HloModule test, replica_count=4

ENTRY test_computation {
param_0 = f32[4] parameter(0)
param_1 = f32[4] parameter(1)
copy_1 = f32[4]{0:S(1)} copy(param_1)

result_tuple = (f32[4], f32[4]{0:S(1)}, u64[7]) custom-call(param_0, copy_1), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}, backend_config="{\"collective_metadata_backend_config\":{\"collective_devices\": { \"replica_groups\": [{\"replica_ids\": [0,1]}, {\"replica_ids\": [2,3]}]}}}"
ROOT get_tuple_element = u64[7] get-tuple-element(result_tuple), index=2
})";

  if (hlo_runner_->device_count() < kNumReplicas) {
    GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
                 << hlo_runner_->device_count() << " available)";
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, kNumReplicas));

  // Both parameters receive the same four-element payload.
  Literal lhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  Literal rhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});

  TF_ASSERT_OK_AND_ASSIGN(
      ExecutionResult run,
      ExecuteReplicated(std::move(hlo_module),
                        /*arguments=*/std::vector<Literal*>{&lhs, &rhs},
                        /*run_hlo_passes=*/false));
  const std::vector<Literal>& outputs = run.results;
  ASSERT_EQ(outputs.size(), kNumReplicas);

  // metadata[r] views the metadata buffer returned by replica r.
  std::array<absl::Span<const uint64_t>, kNumReplicas> metadata;
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    metadata[replica] = outputs[replica].data<uint64_t>();
  }
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    ASSERT_EQ(metadata[replica].size(), kNumElements) << "replica " << replica;
  }

  // Slot 0: the rank, which restarts at 0 inside every replica group.
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    EXPECT_EQ(metadata[replica][0], replica % kGroupSize)
        << "replica " << replica;
  }

  // Slot 1 (pointer to peers) and slot 2 (pointer to multimem metadata) must
  // both be populated on every replica.
  for (int slot = 1; slot <= 2; ++slot) {
    for (int replica = 0; replica < kNumReplicas; ++replica) {
      EXPECT_NE(metadata[replica][slot], 0)
          << "replica " << replica << ", slot " << slot;
    }
  }

  // Remaining slots: the param_to_peers structure, which must be populated
  // and identical for the two members of each replica group.
  for (int slot = 3; slot < kNumElements; ++slot) {
    for (int first = 0; first < kNumReplicas; first += kGroupSize) {
      EXPECT_NE(metadata[first][slot], 0)
          << "replica " << first << ", slot " << slot;
      EXPECT_EQ(metadata[first + 1][slot], metadata[first][slot])
          << "replica " << (first + 1) << ", slot " << slot;
    }
  }
}
175+
176+
} // namespace
177+
} // namespace xla

xla/tests/collective_ops_e2e_test.cc

Lines changed: 0 additions & 145 deletions
Original file line number | Diff line number | Diff line change
@@ -2871,150 +2871,5 @@ INSTANTIATE_TEST_SUITE_P(
28712871
return absl::StrCat(GetAsyncTestName(std::get<0>(info.param)), "_",
28722872
std::get<1>(info.param) ? "one_shot" : "nccl");
28732873
});
2874-
2875-
class CollectiveMetadataTest : public CollectiveOpsE2ETestBase {
2876-
protected:
2877-
void SetUp() override {
2878-
CollectiveOpsE2ETestBase::SetUp();
2879-
if (!IsHopperAndHigher()) {
2880-
GTEST_SKIP() << "Test requires Hopper or newer architecture since it's "
2881-
"using a multicast.";
2882-
}
2883-
}
2884-
};
2885-
2886-
TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadata) {
2887-
const absl::string_view kModuleStr = R"(
2888-
HloModule test, replica_count=2
2889-
2890-
ENTRY test_computation {
2891-
param_0 = f32[4] parameter(0)
2892-
param_1 = f32[4] parameter(1)
2893-
copy_1 = f32[4]{0:S(1)} copy(param_1)
2894-
2895-
const_0 = f32[1] constant({10})
2896-
2897-
result_tuple = (f32[4], f32[4]{0:S(1)}, f32[1], u64[9]) custom-call(param_0, copy_1, const_0), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}
2898-
ROOT get_tuple_element = u64[9] get-tuple-element(result_tuple), index=3
2899-
})";
2900-
2901-
constexpr int kNumReplicas = 2;
2902-
ASSERT_GE(hlo_runner_->device_count(), kNumReplicas)
2903-
<< "Test requires at least " << kNumReplicas << " devices ("
2904-
<< hlo_runner_->device_count() << " available)";
2905-
2906-
TF_ASSERT_OK_AND_ASSIGN(
2907-
auto unoptimized_module,
2908-
ParseAndReturnVerifiedModule(kModuleStr, kNumReplicas));
2909-
2910-
Literal input_0 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
2911-
Literal input_1 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
2912-
TF_ASSERT_OK_AND_ASSIGN(
2913-
ExecutionResult execution_result,
2914-
ExecuteReplicated(std::move(unoptimized_module),
2915-
/*arguments=*/std::vector<Literal*>{&input_0, &input_1},
2916-
/*run_hlo_passes=*/false));
2917-
const std::vector<Literal>& result = execution_result.results;
2918-
ASSERT_EQ(result.size(), kNumReplicas);
2919-
2920-
absl::Span<const uint64_t> first_result_data = result[0].data<uint64_t>();
2921-
absl::Span<const uint64_t> second_result_data = result[1].data<uint64_t>();
2922-
constexpr int kNumElements = 9;
2923-
ASSERT_EQ(first_result_data.size(), kNumElements);
2924-
ASSERT_EQ(second_result_data.size(), kNumElements);
2925-
2926-
// Check the rank in the first position.
2927-
EXPECT_EQ(first_result_data[0], 0);
2928-
EXPECT_EQ(second_result_data[0], 1);
2929-
2930-
// Check pointer to peers in the second position.
2931-
EXPECT_NE(first_result_data[1], 0);
2932-
EXPECT_NE(second_result_data[1], 0);
2933-
2934-
// Check pointer to multimem metadata in the third position.
2935-
EXPECT_NE(first_result_data[2], 0);
2936-
EXPECT_NE(second_result_data[2], 0);
2937-
2938-
// Check param_to_peers structure.
2939-
for (int i = 3; i < kNumElements; ++i) {
2940-
EXPECT_NE(first_result_data[i], 0);
2941-
EXPECT_EQ(second_result_data[i], first_result_data[i]);
2942-
}
2943-
}
2944-
2945-
TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadataWithReplicaGroup) {
2946-
const absl::string_view kModuleStr = R"(
2947-
HloModule test, replica_count=4
2948-
2949-
ENTRY test_computation {
2950-
param_0 = f32[4] parameter(0)
2951-
param_1 = f32[4] parameter(1)
2952-
copy_1 = f32[4]{0:S(1)} copy(param_1)
2953-
2954-
result_tuple = (f32[4], f32[4]{0:S(1)}, u64[7]) custom-call(param_0, copy_1), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}, backend_config="{\"collective_metadata_backend_config\":{\"collective_devices\": { \"replica_groups\": [{\"replica_ids\": [0,1]}, {\"replica_ids\": [2,3]}]}}}"
2955-
ROOT get_tuple_element = u64[7] get-tuple-element(result_tuple), index=2
2956-
})";
2957-
2958-
constexpr int kNumReplicas = 4;
2959-
if (hlo_runner_->device_count() < kNumReplicas) {
2960-
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
2961-
<< hlo_runner_->device_count() << " available)";
2962-
}
2963-
2964-
TF_ASSERT_OK_AND_ASSIGN(
2965-
auto module, ParseAndReturnVerifiedModule(kModuleStr, kNumReplicas));
2966-
2967-
Literal input_0 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
2968-
Literal input_1 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
2969-
2970-
TF_ASSERT_OK_AND_ASSIGN(
2971-
ExecutionResult execution_result,
2972-
ExecuteReplicated(std::move(module),
2973-
/*arguments=*/std::vector<Literal*>{&input_0, &input_1},
2974-
/*run_hlo_passes=*/false));
2975-
const std::vector<Literal>& result = execution_result.results;
2976-
ASSERT_EQ(result.size(), kNumReplicas);
2977-
absl::Span<const uint64_t> replica_0_result_0_data =
2978-
result[0].data<uint64_t>();
2979-
absl::Span<const uint64_t> replica_0_result_1_data =
2980-
result[1].data<uint64_t>();
2981-
absl::Span<const uint64_t> replica_1_result_0_data =
2982-
result[2].data<uint64_t>();
2983-
absl::Span<const uint64_t> replica_1_result_1_data =
2984-
result[3].data<uint64_t>();
2985-
2986-
// Check the rank in the first position.
2987-
constexpr int kNumElements = 7;
2988-
ASSERT_EQ(replica_0_result_0_data.size(), kNumElements);
2989-
ASSERT_EQ(replica_0_result_1_data.size(), kNumElements);
2990-
ASSERT_EQ(replica_1_result_0_data.size(), kNumElements);
2991-
ASSERT_EQ(replica_1_result_1_data.size(), kNumElements);
2992-
2993-
EXPECT_EQ(replica_0_result_0_data[0], 0);
2994-
EXPECT_EQ(replica_0_result_1_data[0], 1);
2995-
EXPECT_EQ(replica_1_result_0_data[0], 0);
2996-
EXPECT_EQ(replica_1_result_1_data[0], 1);
2997-
2998-
// Check pointer to peers in the second position.
2999-
EXPECT_NE(replica_0_result_0_data[1], 0);
3000-
EXPECT_NE(replica_0_result_1_data[1], 0);
3001-
EXPECT_NE(replica_1_result_0_data[1], 0);
3002-
EXPECT_NE(replica_1_result_1_data[1], 0);
3003-
3004-
// Check pointer to multimem metadata in the third position.
3005-
EXPECT_NE(replica_0_result_0_data[2], 0);
3006-
EXPECT_NE(replica_0_result_1_data[2], 0);
3007-
EXPECT_NE(replica_1_result_0_data[2], 0);
3008-
EXPECT_NE(replica_1_result_1_data[2], 0);
3009-
3010-
// Check param_to_peers structure.
3011-
for (int i = 3; i < kNumElements; ++i) {
3012-
EXPECT_NE(replica_0_result_0_data[i], 0);
3013-
EXPECT_EQ(replica_0_result_1_data[i], replica_0_result_0_data[i]);
3014-
EXPECT_NE(replica_1_result_0_data[i], 0);
3015-
EXPECT_EQ(replica_1_result_1_data[i], replica_1_result_0_data[i]);
3016-
}
3017-
}
3018-
30192874
} // namespace
30202875
} // namespace xla

0 commit comments

Comments
 (0)