Skip to content

Commit 5bd3de6

Browse files
khasanovaaGoogle-ML-Automation
authored andcommitted
Add proto [de]serialization for SelectKThunk.
PiperOrigin-RevId: 820210212
1 parent 5df263c commit 5bd3de6

File tree

5 files changed

+76
-15
lines changed

5 files changed

+76
-15
lines changed

xla/backends/gpu/runtime/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,12 +1136,15 @@ cc_library(
11361136
"//xla/stream_executor:device_memory",
11371137
"//xla/stream_executor:device_memory_allocator",
11381138
"//xla/stream_executor:stream",
1139+
"//xla/tsl/platform:statusor",
11391140
"@com_google_absl//absl/container:inlined_vector",
11401141
"@com_google_absl//absl/log",
11411142
"@com_google_absl//absl/log:check",
11421143
"@com_google_absl//absl/status",
11431144
"@com_google_absl//absl/status:statusor",
11441145
"@com_google_absl//absl/strings",
1146+
"@com_google_absl//absl/types:span",
1147+
"@tsl//tsl/platform:statusor",
11451148
] + if_cuda_is_configured(
11461149
[":select_k_exec_raft"],
11471150
no_cuda = [":select_k_exec_stub"],
@@ -1163,6 +1166,7 @@ xla_cc_test(
11631166
"//xla/service:buffer_assignment",
11641167
"//xla/tsl/platform:statusor",
11651168
"//xla/tsl/util/proto:proto_matchers",
1169+
"@com_google_absl//absl/status:status_matchers",
11661170
"@com_google_googletest//:gtest_main",
11671171
],
11681172
)

xla/backends/gpu/runtime/select_k_thunk.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@ limitations under the License.
1616
#include "xla/backends/gpu/runtime/select_k_thunk.h"
1717

1818
#include <cstdint>
19+
#include <memory>
1920
#include <string>
21+
#include <utility>
22+
#include <vector>
2023

2124
#include "absl/container/inlined_vector.h"
2225
#include "absl/log/check.h"
2326
#include "absl/log/log.h"
2427
#include "absl/status/status.h"
2528
#include "absl/status/statusor.h"
2629
#include "absl/strings/str_cat.h"
30+
#include "absl/types/span.h"
2731
#include "xla/backends/gpu/runtime/select_k_exec.h"
2832
#include "xla/backends/gpu/runtime/thunk.h"
2933
#include "xla/backends/gpu/runtime/thunk.pb.h"
@@ -32,9 +36,11 @@ limitations under the License.
3236
#include "xla/hlo/ir/hlo_instruction.h"
3337
#include "xla/primitive_util.h"
3438
#include "xla/service/buffer_assignment.h"
39+
#include "xla/shape.h"
3540
#include "xla/stream_executor/device_memory.h"
3641
#include "xla/stream_executor/device_memory_allocator.h"
3742
#include "xla/stream_executor/stream.h"
43+
#include "xla/tsl/platform/statusor.h"
3844
#include "xla/types.h"
3945

4046
namespace xla::gpu {
@@ -105,10 +111,35 @@ absl::Status SelectKThunk::ExecuteOnStream(const ExecuteParams& params) {
105111
absl::StatusOr<ThunkProto> SelectKThunk::ToProto() const {
106112
ThunkProto proto;
107113
*proto.mutable_thunk_info() = thunk_info().ToProto();
114+
SelectKThunkProto* select_k_proto = proto.mutable_select_k_thunk();
108115

109-
SelectKThunkProto* select_k_thunk_proto = proto.mutable_select_k_thunk();
110-
(void)select_k_thunk_proto;
111-
// TODO(upwind): Add fields for SelectKThunkProto.
116+
select_k_proto->set_batch_size(batch_size_);
117+
select_k_proto->set_num_elements(num_elements_);
118+
select_k_proto->set_k(k_);
119+
select_k_proto->set_dtype(dtype_);
120+
121+
for (const BufferAllocation::Slice& arg : args_) {
122+
TF_ASSIGN_OR_RETURN(*select_k_proto->add_args(), arg.ToProto());
123+
}
112124
return proto;
113125
}
126+
127+
absl::StatusOr<std::unique_ptr<SelectKThunk>> SelectKThunk::FromProto(
128+
ThunkInfo thunk_info, const SelectKThunkProto& proto,
129+
absl::Span<const BufferAllocation> buffer_allocations) {
130+
std::vector<emitters::KernelArgument> arguments;
131+
arguments.reserve(proto.args().size());
132+
for (const xla::buffer_assignment::BufferAllocationSliceProto& arg :
133+
proto.args()) {
134+
TF_ASSIGN_OR_RETURN(
135+
BufferAllocation::Slice slice,
136+
BufferAllocation::Slice::FromProto(arg, buffer_allocations));
137+
emitters::KernelArgument argument{Shape{}, slice};
138+
arguments.push_back(std::move(argument));
139+
}
140+
return std::make_unique<SelectKThunk>(
141+
thunk_info, proto.batch_size(), proto.num_elements(), proto.k(),
142+
proto.dtype(), emitters::KernelArguments(std::move(arguments)));
143+
}
144+
114145
} // namespace xla::gpu

xla/backends/gpu/runtime/select_k_thunk.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define XLA_BACKENDS_GPU_RUNTIME_SELECT_K_THUNK_H_
1818

1919
#include <cstdint>
20+
#include <memory>
2021
#include <string>
2122
#include <vector>
2223

@@ -63,6 +64,10 @@ class SelectKThunk : public Thunk {
6364

6465
absl::StatusOr<ThunkProto> ToProto() const override;
6566

67+
static absl::StatusOr<std::unique_ptr<SelectKThunk>> FromProto(
68+
ThunkInfo thunk_info, const SelectKThunkProto& proto,
69+
absl::Span<const BufferAllocation> buffer_allocations);
70+
6671
private:
6772
std::uint32_t batch_size_;
6873
std::uint32_t num_elements_;

xla/backends/gpu/runtime/select_k_thunk_test.cc

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include <gmock/gmock.h>
2323
#include <gtest/gtest.h>
24+
#include "absl/status/status_matchers.h"
2425
#include "xla/backends/gpu/runtime/thunk.h"
2526
#include "xla/backends/gpu/runtime/thunk.pb.h"
2627
#include "xla/backends/gpu/runtime/thunk_id.h"
@@ -35,6 +36,7 @@ limitations under the License.
3536
namespace xla::gpu {
3637
namespace {
3738

39+
using ::absl_testing::IsOkAndHolds;
3840
using ::tsl::proto_testing::EqualsProto;
3941

4042
TEST(SelectKThunkTest, ToProto) {
@@ -46,30 +48,45 @@ TEST(SelectKThunkTest, ToProto) {
4648
Thunk::ThunkInfo thunk_info =
4749
Thunk::ThunkInfo::WithProfileAnnotation(topKInst.get(), ThunkId{456});
4850

49-
BufferAllocation alloc0(/*index=*/0, /*size=*/20, /*color=*/0);
50-
BufferAllocation::Slice slice0(&alloc0, /*offset=*/0, /*size=*/20);
51+
std::vector<BufferAllocation> buffer_allocations = {
52+
{/*index=*/0, /*size=*/20, /*color=*/0},
53+
{/*index=*/1, /*size=*/12, /*color=*/0},
54+
{/*index=*/2, /*size=*/12, /*color=*/0}};
5155

52-
BufferAllocation alloc1(/*index=*/1, /*size=*/12, /*color=*/0);
53-
BufferAllocation::Slice slice1(&alloc1, /*offset=*/0, /*size=*/12);
54-
55-
BufferAllocation alloc2(/*index=*/2, /*size=*/12, /*color=*/0);
56-
BufferAllocation::Slice slice2(&alloc2, /*offset=*/0, /*size=*/12);
56+
BufferAllocation::Slice slice0(&buffer_allocations[0], /*offset=*/0,
57+
/*size=*/20);
58+
BufferAllocation::Slice slice1(&buffer_allocations[1], /*offset=*/0,
59+
/*size=*/12);
60+
BufferAllocation::Slice slice2(&buffer_allocations[2], /*offset=*/0,
61+
/*size=*/12);
5762

5863
emitters::KernelArgument arg0(ShapeUtil::MakeShape(F32, {1, 5}), slice0);
5964
emitters::KernelArgument arg1(ShapeUtil::MakeShape(F32, {1, 3}), slice1);
6065
emitters::KernelArgument arg2(ShapeUtil::MakeShape(U32, {1, 3}), slice2);
61-
arg0.set_written(false);
62-
arg1.set_written(true);
63-
arg2.set_written(true);
6466

6567
emitters::KernelArguments kernel_arguments({arg0, arg1, arg2});
6668

6769
SelectKThunk thunk(std::move(thunk_info), 1, 5, 3, F32, kernel_arguments);
70+
6871
TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk.ToProto());
6972
EXPECT_THAT(proto, EqualsProto(R"pb(
7073
thunk_info { profile_annotation: "custom-call" thunk_id: 456 }
71-
select_k_thunk {}
74+
select_k_thunk {
75+
args { buffer_allocation_index: 0 size: 20 }
76+
args { buffer_allocation_index: 1 size: 12 }
77+
args { buffer_allocation_index: 2 size: 12 }
78+
batch_size: 1
79+
num_elements: 5
80+
k: 3
81+
dtype: F32
82+
}
7283
)pb"));
84+
85+
TF_ASSERT_OK_AND_ASSIGN(
86+
std::unique_ptr<SelectKThunk> deserialized,
87+
SelectKThunk::FromProto(thunk.thunk_info(), proto.select_k_thunk(),
88+
buffer_allocations));
89+
EXPECT_THAT(deserialized->ToProto(), IsOkAndHolds(EqualsProto(proto)));
7390
}
7491

7592
} // namespace

xla/backends/gpu/runtime/thunk.proto

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,11 @@ message OutfeedThunkProto {
178178
}
179179

180180
message SelectKThunkProto {
181-
// TODO(upwind): Add fields for SelectKThunkProto.
181+
repeated xla.buffer_assignment.BufferAllocationSliceProto args = 1;
182+
uint32 batch_size = 3;
183+
uint32 num_elements = 4;
184+
uint32 k = 5;
185+
xla.PrimitiveType dtype = 6;
182186
}
183187

184188
message CublasLtMatmulThunkProto {

0 commit comments

Comments
 (0)