Skip to content

Commit 1209728

Browse files
Reverts 56bc4f8
PiperOrigin-RevId: 820748684
1 parent 3555e03 commit 1209728

File tree

8 files changed

+184
-15
lines changed

8 files changed

+184
-15
lines changed

xla/backends/cpu/buffer_allocation_info.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ limitations under the License.
2222
#include "absl/strings/str_format.h"
2323
#include "absl/strings/string_view.h"
2424

25-
namespace xla {
26-
namespace cpu {
25+
namespace xla::cpu {
2726

2827
// `BufferAllocationInfo` stores information about buffer allocations required
2928
// by an XLA:CPU executable at run time. It corresponds to a `BufferAllocation`
@@ -194,14 +193,6 @@ class BufferAllocationInfo {
194193
int32_t result_number_ = -1;
195194
};
196195

197-
} // namespace cpu
198-
199-
// TODO(ezhulenev): This is a temporary hack to keep `tfcompile` code working.
200-
namespace cpu_function_runtime {
201-
using BufferInfo = ::xla::cpu::BufferAllocationInfo;
202-
using EncodedBufferInfo = ::xla::cpu::BufferAllocationInfo::Encoded;
203-
} // namespace cpu_function_runtime
204-
205-
} // namespace xla
196+
} // namespace xla::cpu
206197

207198
#endif // XLA_BACKENDS_CPU_BUFFER_ALLOCATION_INFO_H_

xla/cpu_function_runtime.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222
#include <cstdlib>
2323

2424
namespace xla {
25-
namespace cpu_function_runtime_deprecated {
25+
namespace cpu_function_runtime {
2626

2727
struct EncodedBufferInfo {
2828
uint64_t packed_kind_and_size = 0;
@@ -174,7 +174,7 @@ class BufferInfo {
174174
int32_t result_param_number_ = -1;
175175
};
176176

177-
} // namespace cpu_function_runtime_deprecated
177+
} // namespace cpu_function_runtime
178178
} // namespace xla
179179

180180
#endif // XLA_CPU_FUNCTION_RUNTIME_H_

xla/service/cpu/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,25 @@ cc_library(
170170
alwayslink = True, # Contains per-platform transfer manager registration
171171
)
172172

173+
cc_library(
174+
name = "buffer_info_util",
175+
srcs = ["buffer_info_util.cc"],
176+
hdrs = ["buffer_info_util.h"],
177+
deps = [
178+
"//xla:cpu_function_runtime",
179+
"//xla/hlo/ir:hlo",
180+
"//xla/service:buffer_assignment",
181+
"@com_google_absl//absl/types:span",
182+
],
183+
)
184+
173185
cc_library(
174186
name = "cpu_compiler_pure",
175187
srcs = ["cpu_compiler.cc"],
176188
hdrs = ["cpu_compiler.h"],
177189
copts = tsl_copts(),
178190
deps = [
191+
":buffer_info_util",
179192
":conv_canonicalization",
180193
":cpu_aot_compilation_result",
181194
":cpu_aot_loader",
@@ -414,6 +427,7 @@ cc_library(
414427
srcs = ["cpu_aot_compilation_result.cc"],
415428
hdrs = ["cpu_aot_compilation_result.h"],
416429
deps = [
430+
":buffer_info_util",
417431
":cpu_executable",
418432
":executable_proto_cc",
419433
"//xla:cpu_function_runtime",
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/* Copyright 2018 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/service/cpu/buffer_info_util.h"
17+
18+
#include <cassert>
19+
#include <cstdint>
20+
#include <vector>
21+
22+
#include "absl/types/span.h"
23+
#include "xla/cpu_function_runtime.h"
24+
25+
namespace xla {
26+
namespace cpu {
27+
28+
using BufferInfo = cpu_function_runtime::BufferInfo;
29+
30+
std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
31+
const HloModule& module, const BufferAssignment& buffer_assignment) {
32+
std::vector<BufferInfo> buffer_infos;
33+
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
34+
if (allocation.is_thread_local()) {
35+
buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size()));
36+
} else if (allocation.is_constant()) {
37+
buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size()));
38+
} else if (allocation.is_entry_computation_parameter()) {
39+
buffer_infos.push_back(BufferInfo::MakeEntryParameter(
40+
/*size=*/allocation.size(),
41+
/*param_number=*/allocation.parameter_number()));
42+
} else {
43+
buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size()));
44+
}
45+
}
46+
47+
// Fill in the result parameters' indices, expanding all tuples.
48+
auto root_instr = module.entry_computation()->root_instruction();
49+
auto output_allocation = buffer_assignment.GetUniqueTopLevelOutputSlice();
50+
if (output_allocation->allocation()->is_tuple()) {
51+
int out_index = 0;
52+
ShapeUtil::ForEachSubshape(
53+
root_instr->shape(),
54+
[&](const Shape& subshape, const ShapeIndex& index) {
55+
if (subshape.IsTuple()) {
56+
return;
57+
}
58+
int64_t result_index =
59+
buffer_assignment.GetUniqueSlice(root_instr, index)->index();
60+
assert(result_index < buffer_infos.size());
61+
buffer_infos[result_index].set_result_parameter_number(out_index++);
62+
});
63+
}
64+
65+
return buffer_infos;
66+
}
67+
68+
std::vector<int32_t> CreateArgIndexTableFromBufferInfos(
69+
absl::Span<const BufferInfo> buffer_infos) {
70+
std::vector<int32_t> ret;
71+
for (int64_t i = 0; i < buffer_infos.size(); i++) {
72+
if (!buffer_infos[i].is_entry_parameter()) {
73+
continue;
74+
}
75+
uint64_t param_index = buffer_infos[i].entry_parameter_number();
76+
if (param_index >= ret.size()) {
77+
ret.resize(param_index + 1);
78+
}
79+
ret[param_index] = i;
80+
}
81+
return ret;
82+
}
83+
84+
std::vector<int32_t> CreateResultIndexTableFromBufferInfos(
85+
absl::Span<const BufferInfo> buffer_infos) {
86+
std::vector<int32_t> ret;
87+
for (int64_t i = 0; i < buffer_infos.size(); i++) {
88+
if (!buffer_infos[i].is_result_parameter()) {
89+
continue;
90+
}
91+
uint64_t result_index = buffer_infos[i].result_parameter_number();
92+
if (result_index >= ret.size()) {
93+
ret.resize(result_index + 1);
94+
}
95+
ret[result_index] = i;
96+
}
97+
return ret;
98+
}
99+
100+
} // namespace cpu
101+
} // namespace xla

xla/service/cpu/buffer_info_util.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright 2018 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
17+
#define XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
18+
19+
#include <cstdint>
20+
#include <vector>
21+
22+
#include "absl/types/span.h"
23+
#include "xla/cpu_function_runtime.h"
24+
#include "xla/hlo/ir/hlo_module.h"
25+
#include "xla/service/buffer_assignment.h"
26+
27+
namespace xla {
28+
namespace cpu {
29+
// Creates and returns a list of BufferInfo instances containing relevant
30+
// information from `buffer_assignment`.
31+
std::vector<cpu_function_runtime::BufferInfo>
32+
CreateBufferInfosFromBufferAssignment(
33+
const HloModule& module, const BufferAssignment& buffer_assignment);
34+
35+
// Creates and returns a table containing the mapping from entry computation
36+
// parameters to buffer allocation indices.
37+
//
38+
// If this function returns V then entry parameter i has buffer allocation index
39+
// V[i].
40+
std::vector<int32_t> CreateArgIndexTableFromBufferInfos(
41+
absl::Span<const cpu_function_runtime::BufferInfo> buffer_infos);
42+
43+
std::vector<int32_t> CreateResultIndexTableFromBufferInfos(
44+
absl::Span<const cpu_function_runtime::BufferInfo> buffer_infos);
45+
} // namespace cpu
46+
} // namespace xla
47+
48+
#endif // XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_

xla/service/cpu/cpu_aot_compilation_result.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ limitations under the License.
4343
#include "xla/service/buffer_assignment.h"
4444
#include "xla/service/buffer_value.h"
4545
#include "xla/service/compiler.h"
46+
#include "xla/service/cpu/buffer_info_util.h"
4647
#include "xla/service/cpu/cpu_executable.h"
4748
#include "xla/service/cpu/executable.pb.h"
4849
#include "xla/service/executable.h"
@@ -56,6 +57,7 @@ limitations under the License.
5657
#include "xla/util.h"
5758

5859
namespace xla::cpu {
60+
using BufferInfo = cpu_function_runtime::BufferInfo;
5961

6062
CpuAotCompilationOptions::CpuAotCompilationOptions(
6163
std::string triple, std::string cpu_name, std::string features,
@@ -86,10 +88,14 @@ CpuAotCompilationResult::Create(
8688
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
8789
thunk_sequence_serdes.ToProto(thunks));
8890

91+
std::vector<cpu_function_runtime::BufferInfo> buffer_infos;
8992
std::vector<cpu::BufferAllocationInfo> buffer_allocation_infos;
9093
std::optional<size_t> temp_allocation_index;
9194

9295
if (buffer_assignment) {
96+
buffer_infos =
97+
CreateBufferInfosFromBufferAssignment(*hlo_module, *buffer_assignment);
98+
9399
buffer_allocation_infos =
94100
CreateBufferAllocationInfos(*hlo_module, *buffer_assignment);
95101

@@ -108,19 +114,21 @@ CpuAotCompilationResult::Create(
108114
return absl::WrapUnique(new CpuAotCompilationResult(
109115
hlo_module, buffer_assignment, function_name, std::move(obj_files),
110116
std::move(symbols), thunk_proto, std::move(temp_allocation_index),
111-
std::move(buffer_allocation_infos), std::move(function_library),
112-
std::move(hlo_profile_printer_data)));
117+
std::move(buffer_infos), std::move(buffer_allocation_infos),
118+
std::move(function_library), std::move(hlo_profile_printer_data)));
113119
}
114120

115121
CpuAotCompilationResult::CpuAotCompilationResult(
116122
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
117123
absl::string_view function_name, std::vector<ObjFileProto> obj_files,
118124
std::vector<SymbolProto> symbols, const ThunkSequenceProto& thunks,
119125
std::optional<size_t> temp_allocation_index,
126+
std::vector<cpu_function_runtime::BufferInfo> buffer_infos,
120127
std::vector<BufferAllocationInfo> buffer_allocation_infos,
121128
std::unique_ptr<FunctionLibrary> function_library,
122129
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
123130
: temp_allocation_index_(temp_allocation_index),
131+
buffer_infos_(std::move(buffer_infos)),
124132
buffer_allocation_infos_(std::move(buffer_allocation_infos)),
125133
function_library_(std::move(function_library)),
126134
hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {

xla/service/cpu/cpu_aot_compilation_result.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ class CpuAotCompilationResult : public AotCompilationResult {
148148
return temp_allocation_index_;
149149
}
150150

151+
const std::vector<cpu_function_runtime::BufferInfo>& buffer_infos() const {
152+
return buffer_infos_;
153+
}
154+
151155
absl::Span<const BufferAllocationInfo> buffer_allocation_infos() const {
152156
return buffer_allocation_infos_;
153157
}
@@ -184,6 +188,7 @@ class CpuAotCompilationResult : public AotCompilationResult {
184188
absl::string_view function_name, std::vector<ObjFileProto> obj_files,
185189
std::vector<SymbolProto> symbols, const ThunkSequenceProto& thunks,
186190
std::optional<size_t> temp_allocation_index,
191+
std::vector<cpu_function_runtime::BufferInfo> buffer_infos,
187192
std::vector<BufferAllocationInfo> buffer_allocation_infos,
188193
std::unique_ptr<FunctionLibrary> function_library,
189194
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
@@ -198,6 +203,7 @@ class CpuAotCompilationResult : public AotCompilationResult {
198203
CompilationResultProto proto_;
199204
std::unique_ptr<HloModule> module_;
200205
std::optional<size_t> temp_allocation_index_;
206+
std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;
201207
std::vector<BufferAllocationInfo> buffer_allocation_infos_;
202208

203209
std::unique_ptr<FunctionLibrary> function_library_;

xla/service/cpu/cpu_compiler.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ limitations under the License.
172172
#include "xla/service/conditional_simplifier.h"
173173
#include "xla/service/conditional_to_select.h"
174174
#include "xla/service/copy_insertion.h"
175+
#include "xla/service/cpu/buffer_info_util.h"
175176
#include "xla/service/cpu/conv_canonicalization.h"
176177
#include "xla/service/cpu/cpu_aot_compilation_result.h"
177178
#include "xla/service/cpu/cpu_aot_loader.h"

0 commit comments

Comments
 (0)