|
| 1 | +/* Copyright 2025 The OpenXLA Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | +#include <array> |
| 16 | +#include <cstdint> |
| 17 | +#include <memory> |
| 18 | +#include <utility> |
| 19 | +#include <vector> |
| 20 | + |
| 21 | +#include "absl/strings/string_view.h" |
| 22 | +#include "absl/types/span.h" |
| 23 | +#include "xla/literal.h" |
| 24 | +#include "xla/literal_util.h" |
| 25 | +#include "xla/tests/collective_ops_e2e_test_base.h" |
| 26 | +#include "xla/tsl/platform/statusor.h" |
| 27 | +#include "xla/tsl/platform/test.h" |
| 28 | + |
| 29 | +namespace xla { |
| 30 | +namespace { |
| 31 | + |
| 32 | +class CollectiveMetadataTest : public CollectiveOpsE2ETestBase { |
| 33 | + protected: |
| 34 | + void SetUp() override { |
| 35 | + CollectiveOpsE2ETestBase::SetUp(); |
| 36 | + if (!IsHopperAndHigher()) { |
| 37 | + GTEST_SKIP() << "Test requires Hopper or newer architecture since it's " |
| 38 | + "using a multicast."; |
| 39 | + } |
| 40 | + } |
| 41 | +}; |
| 42 | + |
| 43 | +TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadata) { |
| 44 | + const absl::string_view kModuleStr = R"( |
| 45 | + HloModule test, replica_count=2 |
| 46 | +
|
| 47 | + ENTRY test_computation { |
| 48 | + param_0 = f32[4] parameter(0) |
| 49 | + param_1 = f32[4] parameter(1) |
| 50 | + copy_1 = f32[4]{0:S(1)} copy(param_1) |
| 51 | +
|
| 52 | + const_0 = f32[1] constant({10}) |
| 53 | +
|
| 54 | + result_tuple = (f32[4], f32[4]{0:S(1)}, f32[1], u64[9]) custom-call(param_0, copy_1, const_0), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})} |
| 55 | + ROOT get_tuple_element = u64[9] get-tuple-element(result_tuple), index=3 |
| 56 | + })"; |
| 57 | + |
| 58 | + constexpr int kNumReplicas = 2; |
| 59 | + ASSERT_GE(hlo_runner_->device_count(), kNumReplicas) |
| 60 | + << "Test requires at least " << kNumReplicas << " devices (" |
| 61 | + << hlo_runner_->device_count() << " available)"; |
| 62 | + |
| 63 | + TF_ASSERT_OK_AND_ASSIGN( |
| 64 | + auto unoptimized_module, |
| 65 | + ParseAndReturnVerifiedModule(kModuleStr, kNumReplicas)); |
| 66 | + |
| 67 | + Literal input_0 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f}); |
| 68 | + Literal input_1 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f}); |
| 69 | + TF_ASSERT_OK_AND_ASSIGN( |
| 70 | + ExecutionResult execution_result, |
| 71 | + ExecuteReplicated(std::move(unoptimized_module), |
| 72 | + /*arguments=*/std::vector<Literal*>{&input_0, &input_1}, |
| 73 | + /*run_hlo_passes=*/false)); |
| 74 | + const std::vector<Literal>& result = execution_result.results; |
| 75 | + ASSERT_EQ(result.size(), kNumReplicas); |
| 76 | + |
| 77 | + absl::Span<const uint64_t> first_result_data = result[0].data<uint64_t>(); |
| 78 | + absl::Span<const uint64_t> second_result_data = result[1].data<uint64_t>(); |
| 79 | + constexpr int kNumElements = 9; |
| 80 | + ASSERT_EQ(first_result_data.size(), kNumElements); |
| 81 | + ASSERT_EQ(second_result_data.size(), kNumElements); |
| 82 | + |
| 83 | + // Check the rank in the first position. |
| 84 | + EXPECT_EQ(first_result_data[0], 0); |
| 85 | + EXPECT_EQ(second_result_data[0], 1); |
| 86 | + |
| 87 | + // Check pointer to peers in the second position. |
| 88 | + EXPECT_NE(first_result_data[1], 0); |
| 89 | + EXPECT_NE(second_result_data[1], 0); |
| 90 | + |
| 91 | + // Check pointer to multimem metadata in the third position. |
| 92 | + EXPECT_NE(first_result_data[2], 0); |
| 93 | + EXPECT_NE(second_result_data[2], 0); |
| 94 | + |
| 95 | + // Check param_to_peers structure. |
| 96 | + for (int i = 3; i < kNumElements; ++i) { |
| 97 | + EXPECT_NE(first_result_data[i], 0); |
| 98 | + EXPECT_EQ(second_result_data[i], first_result_data[i]); |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +TEST_F(CollectiveMetadataTest, ConstructCollectiveMetadataWithReplicaGroup) { |
| 103 | + const absl::string_view kModuleStr = R"( |
| 104 | + HloModule test, replica_count=4 |
| 105 | +
|
| 106 | + ENTRY test_computation { |
| 107 | + param_0 = f32[4] parameter(0) |
| 108 | + param_1 = f32[4] parameter(1) |
| 109 | + copy_1 = f32[4]{0:S(1)} copy(param_1) |
| 110 | +
|
| 111 | + result_tuple = (f32[4], f32[4]{0:S(1)}, u64[7]) custom-call(param_0, copy_1), custom_call_target="CollectiveMetadata", output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}, backend_config="{\"collective_metadata_backend_config\":{\"collective_devices\": { \"replica_groups\": [{\"replica_ids\": [0,1]}, {\"replica_ids\": [2,3]}]}}}" |
| 112 | + ROOT get_tuple_element = u64[7] get-tuple-element(result_tuple), index=2 |
| 113 | + })"; |
| 114 | + |
| 115 | + constexpr int kNumReplicas = 4; |
| 116 | + if (hlo_runner_->device_count() < kNumReplicas) { |
| 117 | + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" |
| 118 | + << hlo_runner_->device_count() << " available)"; |
| 119 | + } |
| 120 | + |
| 121 | + TF_ASSERT_OK_AND_ASSIGN( |
| 122 | + auto module, ParseAndReturnVerifiedModule(kModuleStr, kNumReplicas)); |
| 123 | + |
| 124 | + Literal input_0 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f}); |
| 125 | + Literal input_1 = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f}); |
| 126 | + |
| 127 | + TF_ASSERT_OK_AND_ASSIGN( |
| 128 | + ExecutionResult execution_result, |
| 129 | + ExecuteReplicated(std::move(module), |
| 130 | + /*arguments=*/std::vector<Literal*>{&input_0, &input_1}, |
| 131 | + /*run_hlo_passes=*/false)); |
| 132 | + const std::vector<Literal>& result = execution_result.results; |
| 133 | + ASSERT_EQ(result.size(), kNumReplicas); |
| 134 | + absl::Span<const uint64_t> replica_0_result_0_data = |
| 135 | + result[0].data<uint64_t>(); |
| 136 | + absl::Span<const uint64_t> replica_0_result_1_data = |
| 137 | + result[1].data<uint64_t>(); |
| 138 | + absl::Span<const uint64_t> replica_1_result_0_data = |
| 139 | + result[2].data<uint64_t>(); |
| 140 | + absl::Span<const uint64_t> replica_1_result_1_data = |
| 141 | + result[3].data<uint64_t>(); |
| 142 | + |
| 143 | + // Check the rank in the first position. |
| 144 | + constexpr int kNumElements = 7; |
| 145 | + ASSERT_EQ(replica_0_result_0_data.size(), kNumElements); |
| 146 | + ASSERT_EQ(replica_0_result_1_data.size(), kNumElements); |
| 147 | + ASSERT_EQ(replica_1_result_0_data.size(), kNumElements); |
| 148 | + ASSERT_EQ(replica_1_result_1_data.size(), kNumElements); |
| 149 | + |
| 150 | + EXPECT_EQ(replica_0_result_0_data[0], 0); |
| 151 | + EXPECT_EQ(replica_0_result_1_data[0], 1); |
| 152 | + EXPECT_EQ(replica_1_result_0_data[0], 0); |
| 153 | + EXPECT_EQ(replica_1_result_1_data[0], 1); |
| 154 | + |
| 155 | + // Check pointer to peers in the second position. |
| 156 | + EXPECT_NE(replica_0_result_0_data[1], 0); |
| 157 | + EXPECT_NE(replica_0_result_1_data[1], 0); |
| 158 | + EXPECT_NE(replica_1_result_0_data[1], 0); |
| 159 | + EXPECT_NE(replica_1_result_1_data[1], 0); |
| 160 | + |
| 161 | + // Check pointer to multimem metadata in the third position. |
| 162 | + EXPECT_NE(replica_0_result_0_data[2], 0); |
| 163 | + EXPECT_NE(replica_0_result_1_data[2], 0); |
| 164 | + EXPECT_NE(replica_1_result_0_data[2], 0); |
| 165 | + EXPECT_NE(replica_1_result_1_data[2], 0); |
| 166 | + |
| 167 | + // Check param_to_peers structure. |
| 168 | + for (int i = 3; i < kNumElements; ++i) { |
| 169 | + EXPECT_NE(replica_0_result_0_data[i], 0); |
| 170 | + EXPECT_EQ(replica_0_result_1_data[i], replica_0_result_0_data[i]); |
| 171 | + EXPECT_NE(replica_1_result_0_data[i], 0); |
| 172 | + EXPECT_EQ(replica_1_result_1_data[i], replica_1_result_0_data[i]); |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +} // namespace |
| 177 | +} // namespace xla |
0 commit comments