Skip to content

Commit 7d6b101

Browse files
[XLA:GPU] Pass peer pointers for each kernel parameter.
PiperOrigin-RevId: 828345912
1 parent 18b1bae commit 7d6b101

File tree

8 files changed

+225
-90
lines changed

8 files changed

+225
-90
lines changed

xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -401,22 +401,26 @@ def TTXLA_GetRankOp : TTXLA_Op<"get_rank", [Pure]> {
401401
def TTXLA_GetPeerPtrOp : TTXLA_Op<"get_peer_ptr", [Pure]> {
402402
let summary = [{
403403
Extract the pointer to the given symmetric memory `address` on the given
404-
`peer` device using the symmetric memory `metadata`.
405-
For this an operation first calculates an offset of the `address` to the
406-
current rank symmetric memory range, and the adds this offset to the
407-
symmetric memory range of the `peer` device.
404+
`peer` device. An `address` should point to the memory of the given kernel
405+
argument with `argument_index`. The result is calculated using the symmetric
406+
memory `metadata` constructed at the runtime.
407+
To calculate offsets, the operation also needs to know the number of devices
408+
participating in the collective operation (`world_size`).
408409
}];
409410
let arguments = (ins
410411
Arg<TT_PtrLike, "",
411412
[MemRead<GlobalMemory>]>:$address,
412413
I64:$peer_id,
413414
Arg<TT_PtrLike, "",
414-
[MemRead<GlobalMemory>]>:$metadata);
415+
[MemRead<GlobalMemory>]>:$metadata,
416+
I32Attr:$argument_index,
417+
// The number of devices participating in the collective operation.
418+
I32Attr:$world_size);
415419

416420
let results = (outs Arg<TT_PtrLike, "", [MemRead<GlobalMemory>]>:$result);
417421

418422
let assemblyFormat = [{
419-
$address `,` $peer_id `,` $metadata attr-dict `:`
423+
$address `,` $peer_id `,` $metadata `,` attr-dict `:`
420424
functional-type(operands, results)
421425
}];
422426
}

xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_remote_access.mlir

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,86 @@ tt.func @get_rank(
1212
}
1313

1414
tt.func @get_peer_ptr(
15-
%arg0: !tt.ptr<i64>, %peer_id: i64, %metadata: !tt.ptr<i64>
15+
%arg0: !tt.ptr<i64>, %arg1: !tt.ptr<i64>, %peer_id: i64, %metadata: !tt.ptr<i64>
1616
) -> !tt.ptr<i64> {
1717
// CHECK-NOT: triton_xla.get_peer_ptr
18-
// Byte size of a pointer.
18+
// An offset from the beginning of metadata to the peer pointers for the %arg1
19+
// offsetof(param_to_peers) + sizeof(uint64_t) * world_size * argument_index = 8 + 8 * 2 * 1 = 24
20+
// CHECK: %c24_i64 = arith.constant 24 : i64
21+
// Size of the uint64_t.
1922
// CHECK: %c8_i64 = arith.constant 8 : i64
2023

2124
// Load metadata->rank
22-
// CHECK-NEXT: %0 = tt.load %arg2 : !tt.ptr<i64>
25+
// CHECK-NEXT: %0 = tt.load %arg3 : !tt.ptr<i64>
2326

2427
// Calculate offset to current base pointer.
2528
// CHECK-NEXT: %1 = arith.muli %0, %c8_i64 : i64
2629

27-
// Load metadata->buffer_root_ptrs[metadata->rank].
30+
// Load metadata->param_to_peers[argument_offset + metadata->rank].
31+
// Here argument_offset = 0 since %arg0 is the first argument.
2832
// CHECK-NEXT: %2 = arith.addi %1, %c8_i64 : i64
29-
// CHECK-NEXT: %3 = tt.addptr %arg2, %2 : !tt.ptr<i64>, i64
33+
// CHECK-NEXT: %3 = tt.addptr %arg3, %2 : !tt.ptr<i64>, i64
3034
// CHECK-NEXT: %4 = tt.load %3 : !tt.ptr<i64>
3135

3236
// Calculate offset to address.
3337
// CHECK-NEXT: %5 = tt.ptr_to_int %arg0 : !tt.ptr<i64> -> i64
3438
// CHECK-NEXT: %6 = arith.subi %5, %4 : i64
3539

3640
// Calculate offset to peer base pointer.
37-
// CHECK-NEXT: %7 = arith.muli %arg1, %c8_i64 : i64
41+
// CHECK-NEXT: %7 = arith.muli %arg2, %c8_i64 : i64
3842
// CHECK-NEXT: %8 = arith.addi %7, %c8_i64 : i64
3943

40-
// Load metadata->buffer_root_ptrs[peer_id].
41-
// CHECK-NEXT: %9 = tt.addptr %arg2, %8 : !tt.ptr<i64>, i64
44+
// Load metadata->param_to_peers[argument_offset + peer_id].
45+
// CHECK-NEXT: %9 = tt.addptr %arg3, %8 : !tt.ptr<i64>, i64
4246
// CHECK-NEXT: %10 = tt.load %9 : !tt.ptr<i64>
4347

44-
// Load metadata->buffer_root_ptrs[peer_id] + offset.
48+
// Calculate metadata->param_to_peers[argument_offset + peer_id] + offset.
4549
// CHECK-NEXT: %11 = arith.addi %10, %6 : i64
4650
// CHECK-NEXT: %12 = tt.int_to_ptr %11 : i64 -> !tt.ptr<i64>
47-
// CHECK-NEXT: tt.return %12 : !tt.ptr<i64>
48-
%peer_ptr = triton_xla.get_peer_ptr %arg0, %peer_id, %metadata : (!tt.ptr<i64>, i64, !tt.ptr<i64>) -> !tt.ptr<i64>
49-
tt.return %peer_ptr : !tt.ptr<i64>
51+
%arg_0_peer_ptr = triton_xla.get_peer_ptr %arg0, %peer_id, %metadata,
52+
{ argument_index = 0 : i32, world_size = 2 : i32 } :
53+
(!tt.ptr<i64>, i64, !tt.ptr<i64>) -> !tt.ptr<i64>
54+
55+
// Load metadata->rank
56+
// CHECK-NEXT: %13 = tt.load %arg3 : !tt.ptr<i64>
57+
// Calculate offset to current base pointer.
58+
// CHECK-NEXT: %14 = arith.muli %13, %c8_i64 : i64
59+
// Load metadata->param_to_peers[argument_offset + metadata->rank].
60+
// CHECK-NEXT: %15 = arith.addi %14, %c24_i64 : i64
61+
// CHECK-NEXT: %16 = tt.addptr %arg3, %15 : !tt.ptr<i64>, i64
62+
// CHECK-NEXT: %17 = tt.load %16 : !tt.ptr<i64>
63+
// Calculate offset to address.
64+
// CHECK-NEXT: %18 = tt.ptr_to_int %arg1 : !tt.ptr<i64> -> i64
65+
// CHECK-NEXT: %19 = arith.subi %18, %17 : i64
66+
67+
// Calculate offset to peer base pointer.
68+
// CHECK-NEXT: %20 = arith.muli %arg2, %c8_i64 : i64
69+
// CHECK-NEXT: %21 = arith.addi %20, %c24_i64 : i64
70+
71+
// Load metadata->param_to_peers[argument_offset + peer_id].
72+
// CHECK-NEXT: %22 = tt.addptr %arg3, %21 : !tt.ptr<i64>, i64
73+
// CHECK-NEXT: %23 = tt.load %22 : !tt.ptr<i64>
74+
75+
// Calculate metadata->param_to_peers[argument_offset + peer_id] + offset.
76+
// CHECK-NEXT: %24 = arith.addi %23, %19 : i64
77+
// CHECK-NEXT: %25 = tt.int_to_ptr %24 : i64 -> !tt.ptr<i64>
78+
79+
%arg_1_peer_ptr = triton_xla.get_peer_ptr %arg1, %peer_id, %metadata,
80+
{ argument_index = 1 : i32, world_size = 2 : i32 } :
81+
(!tt.ptr<i64>, i64, !tt.ptr<i64>) -> !tt.ptr<i64>
82+
83+
// Avoid optimizing away the get_peer_ptr calls by returning the bitwise OR of the two
84+
// peer pointers.
85+
//
86+
// CHECK-NEXT: %26 = tt.ptr_to_int %12 : !tt.ptr<i64> -> i64
87+
%int_arg0 = tt.ptr_to_int %arg_0_peer_ptr : !tt.ptr<i64> -> i64
88+
// CHECK-NEXT: %27 = tt.ptr_to_int %25 : !tt.ptr<i64> -> i64
89+
%int_arg1 = tt.ptr_to_int %arg_1_peer_ptr : !tt.ptr<i64> -> i64
90+
91+
// CHECK-NEXT: %28 = arith.ori %26, %27 : i64
92+
%result_int = arith.ori %int_arg0, %int_arg1 : i64
93+
// CHECK-NEXT: %29 = tt.int_to_ptr %28 : i64 -> !tt.ptr<i64>
94+
%result_ptr = tt.int_to_ptr %result_int : i64 -> !tt.ptr<i64>
95+
// CHECK-NEXT: tt.return %29 : !tt.ptr<i64>
96+
tt.return %result_ptr : !tt.ptr<i64>
5097
}

xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_remote_access_pass.cc

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,14 @@ LogicalResult LowerGetRankOp(GetRankOp get_rank, PatternRewriter& rewriter) {
6767

6868
// The peer address should be computed as follows:
6969
//
70-
// offset = address - metadata->buffer_root_ptrs[metadata->rank].
71-
// peer_address = metadata->buffer_root_ptrs[peer_id] + offset.
70+
// argument_offset = world_size * argument_index
71+
// argument_base = metadata->param_to_peers[argument_offset + metadata->rank]
72+
// offset = address - argument_base
73+
// peer_base = metadata->param_to_peers[argument_offset + peer_id]
74+
// peer_address = peer_base + offset
75+
//
76+
// For more details regarding the peer pointers layout, see the comments in
77+
// `stream_executor::gpu::CollectiveKernelMetadata`.
7278
LogicalResult LowerGetPeerPtrOp(GetPeerPtrOp get_peer_ptr,
7379
PatternRewriter& rewriter) {
7480
Value metadata = get_peer_ptr.getMetadata();
@@ -94,16 +100,26 @@ LogicalResult LowerGetPeerPtrOp(GetPeerPtrOp get_peer_ptr,
94100
// 1. Load metadata->rank.
95101
Value current_rank_load_op = builder.create<GetRankOp>(metadata);
96102

97-
// 2. Load metadata->buffer_root_ptrs[metadata->rank].
103+
// 2. Calculate argument_offset in bytes = world_size * argument_index * sizeof(uint64_t).
104+
const int32_t argument_index = get_peer_ptr.getArgumentIndex();
105+
const int32_t world_size = get_peer_ptr.getWorldSize();
106+
const int32_t argument_offset =
107+
world_size * argument_index * sizeof(uint64_t);
108+
109+
// 3. Load metadata->param_to_peers[argument_offset + metadata->rank].
98110
Value local_buffers_ptrs_offset = builder.create<arith::ConstantIntOp>(
99-
type_i64, offsetof(CollectiveKernelMetadata, buffer_root_ptrs));
111+
type_i64, offsetof(CollectiveKernelMetadata, param_to_peers));
100112

101113
Value rank_offset =
102114
builder.create<arith::ExtUIOp>(type_i64, current_rank_load_op);
115+
Value argument_offset_bytes =
116+
builder.create<arith::ConstantIntOp>(type_i64, argument_offset);
103117
Value current_rank_offset_bytes =
104118
builder.create<arith::MulIOp>(rank_offset, pointer_size_bytes_const);
119+
Value argument_ptr_offset_bytes = builder.create<arith::AddIOp>(
120+
local_buffers_ptrs_offset, argument_offset_bytes);
105121
Value current_ptr_offset_bytes = builder.create<arith::AddIOp>(
106-
local_buffers_ptrs_offset, current_rank_offset_bytes);
122+
argument_ptr_offset_bytes, current_rank_offset_bytes);
107123

108124
Value current_range_address = builder.create<AddPtrOp>(
109125
metadata.getType(), metadata, current_ptr_offset_bytes);
@@ -115,19 +131,19 @@ LogicalResult LowerGetPeerPtrOp(GetPeerPtrOp get_peer_ptr,
115131
EvictionPolicyAttr::get(ctx, EvictionPolicy::NORMAL),
116132
/*isVolatile=*/rewriter.getBoolAttr(false));
117133

118-
// 3. Calculate offset =
119-
// address - metadata->buffer_root_ptrs[metadata->rank].
134+
// 4. Calculate offset =
135+
// address - metadata->param_to_peers[argument_offset + metadata->rank].
120136
Value current_range_address_int =
121137
builder.create<PtrToIntOp>(type_i64, address);
122138
Value offsetInt = builder.create<arith::SubIOp>(current_range_address_int,
123139
current_range_address_value);
124140

125-
// 4. Load metadata->buffer_root_ptrs[peer_id].
141+
// 5. Load metadata->param_to_peers[argument_offset + peer_id].
126142
Value peer_index = builder.create<arith::ExtUIOp>(type_i64, peer_id);
127143
Value peer_index_offset_bytes =
128144
builder.create<arith::MulIOp>(peer_index, pointer_size_bytes_const);
129145
Value peer_range_offset_bytes = builder.create<arith::AddIOp>(
130-
local_buffers_ptrs_offset, peer_index_offset_bytes);
146+
argument_ptr_offset_bytes, peer_index_offset_bytes);
131147
Value peer_range_address = builder.create<AddPtrOp>(
132148
metadata.getType(), metadata, peer_range_offset_bytes);
133149

@@ -138,7 +154,7 @@ LogicalResult LowerGetPeerPtrOp(GetPeerPtrOp get_peer_ptr,
138154
EvictionPolicyAttr::get(ctx, EvictionPolicy::NORMAL),
139155
/*isVolatile=*/rewriter.getBoolAttr(false));
140156

141-
// 5. Calculate the result address: peerBasePtr + offset.
157+
// 6. Calculate the result address: peerBasePtr + offset.
142158
Value result_int =
143159
builder.create<arith::AddIOp>(peer_range_address_value, offsetInt);
144160
Value result_address = builder.create<IntToPtrOp>(result_type, result_int);

xla/backends/gpu/runtime/all_reduce_test.cc

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/backends/gpu/runtime/all_reduce.h"
1717

1818
#include <algorithm>
19+
#include <cstddef>
1920
#include <cstdint>
2021
#include <memory>
2122
#include <tuple>
@@ -150,15 +151,22 @@ class AllReduceKernelTest : public ::testing::Test,
150151
}
151152

152153
std::vector<se::DeviceMemoryBase> metadata_buffers;
154+
// One for signal and one for input parameters.
155+
constexpr int kNumPeerParameters = 2;
156+
size_t param_to_peers_size =
157+
sizeof(uint64_t) * kNumPeerParameters * num_ranks;
158+
std::vector<uint64_t> param_to_peers_ptrs;
159+
for (const auto& local_input_buffer : local_input_buffers) {
160+
param_to_peers_ptrs.push_back((uint64_t)local_input_buffer.opaque());
161+
}
162+
for (const auto& signal_flags_buffer : signal_flags_buffers) {
163+
param_to_peers_ptrs.push_back((uint64_t)signal_flags_buffer.opaque());
164+
}
153165

154166
for (int i = 0; i < num_ranks; ++i) {
155167
CollectiveKernelMetadata metadata;
156168
metadata.rank = i;
157169

158-
for (int j = 0; j < num_ranks; ++j) {
159-
metadata.buffer_root_ptrs[j] = (uint64_t)allocated_buffers[j].opaque();
160-
}
161-
162170
if (params_.all_reduce_strategy == AllReduceStrategy::kMultimem) {
163171
stream_executor::gpu::GpuExecutor* gpu_executor =
164172
dynamic_cast<stream_executor::gpu::GpuExecutor*>(executors[i]);
@@ -171,11 +179,21 @@ class AllReduceKernelTest : public ::testing::Test,
171179
metadata.multicast_buffer_ptr = 0;
172180
}
173181

182+
// Buffer layout: the CollectiveKernelMetadata struct first, followed by the parameter-to-peer pointers table.
174183
metadata_buffers.emplace_back(executors[i]->AllocateArray<uint64_t>(
175-
sizeof(CollectiveKernelMetadata)));
184+
sizeof(CollectiveKernelMetadata) + param_to_peers_size));
185+
186+
se::DeviceMemoryBase param_to_peers_ptrs_buffer =
187+
metadata_buffers[i].GetByteSlice(sizeof(CollectiveKernelMetadata),
188+
param_to_peers_size);
189+
metadata.param_to_peers =
190+
reinterpret_cast<uint64_t*>(param_to_peers_ptrs_buffer.opaque());
176191

177192
TF_RETURN_IF_ERROR(streams[i]->Memcpy(&metadata_buffers[i], &metadata,
178193
sizeof(CollectiveKernelMetadata)));
194+
TF_RETURN_IF_ERROR(streams[i]->Memcpy(&param_to_peers_ptrs_buffer,
195+
param_to_peers_ptrs.data(),
196+
param_to_peers_size));
179197
}
180198

181199
for (int i = 0; i < num_ranks; ++i) {

0 commit comments

Comments
 (0)