Skip to content

Commit b1ee739

Browse files
beckerheGoogle-ML-Automation
authored andcommitted
Fix lifetime issue in CompilationEnvironmentsTest
`CompilationEnvironments::RegisterProcessNewEnvFn` requires the passed in `Descriptor*` pointer to be valid for the lifetime of the program. The test `GetEnvTriggersFullNameFallback` dynamically registers a processing function for a custom descriptor and achieves the lifetime requirements by leaking the `DescriptorPool` instance (call to operator new). This change fixes that properly by adding a deregister function to `CompilationEnvironments` and deregisters the processing function at the end of the test case. It also removes all usages of `tsl::protobuf` and replaces them by `proto2`. PiperOrigin-RevId: 853274937
1 parent 06490a0 commit b1ee739

File tree

4 files changed

+92
-67
lines changed

4 files changed

+92
-67
lines changed

xla/service/BUILD

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5530,14 +5530,13 @@ cc_library(
55305530
"@com_google_absl//absl/container:flat_hash_map",
55315531
"@com_google_absl//absl/log",
55325532
"@com_google_absl//absl/log:check",
5533-
"@com_google_absl//absl/memory",
55345533
"@com_google_absl//absl/status",
55355534
"@com_google_absl//absl/status:statusor",
55365535
"@com_google_absl//absl/strings",
55375536
"@com_google_absl//absl/synchronization",
55385537
"@com_google_protobuf//:any_cc_proto",
5538+
"@com_google_protobuf//:protobuf",
55395539
"@tsl//tsl/platform:logging", # fixdeps: keep
5540-
"@tsl//tsl/platform:protobuf",
55415540
],
55425541
)
55435542

@@ -5566,16 +5565,18 @@ xla_cc_test(
55665565
deps = [
55675566
":compilation_environments",
55685567
":test_compilation_environment_proto_cc",
5568+
# "@com_google_protobuf//:descriptor_cc_proto" - This target is included in @com_google_protobuf//:protobuf and not available separately in OSS.
5569+
"@com_google_googletest//:gtest",
5570+
"@com_google_absl//absl/cleanup",
5571+
"@com_google_absl//absl/status",
5572+
"@com_google_absl//absl/status:status_matchers",
5573+
"@com_google_protobuf//:protobuf",
55695574
"//xla:xla_proto_cc",
55705575
"//xla/hlo/testlib:test",
55715576
"//xla/tests:xla_internal_test_main",
55725577
"//xla/tsl/lib/core:status_test_util",
55735578
"//xla/tsl/platform:statusor",
5574-
"@com_google_absl//absl/status",
5575-
"@com_google_absl//absl/status:status_matchers",
5576-
"@com_google_googletest//:gtest",
55775579
"@tsl//tsl/platform:casts",
5578-
"@tsl//tsl/platform:protobuf",
55795580
],
55805581
)
55815582

xla/service/compilation_environments.cc

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,23 @@ limitations under the License.
2828
#include "absl/container/flat_hash_map.h"
2929
#include "absl/log/check.h"
3030
#include "absl/log/log.h"
31-
#include "absl/memory/memory.h"
3231
#include "absl/status/status.h"
3332
#include "absl/status/statusor.h"
3433
#include "absl/strings/str_cat.h"
3534
#include "absl/strings/str_join.h"
3635
#include "absl/strings/string_view.h"
3736
#include "absl/synchronization/mutex.h"
37+
#include "google/protobuf/descriptor.h"
38+
#include "google/protobuf/message.h"
39+
#include "google/protobuf/unknown_field_set.h"
3840
#include "xla/tsl/platform/errors.h"
3941
#include "xla/tsl/platform/statusor.h"
40-
#include "tsl/platform/protobuf.h"
4142

4243
namespace xla {
4344
namespace {
4445

4546
ABSL_CONST_INIT absl::Mutex process_new_env_fns_mu(absl::kConstInit);
46-
absl::flat_hash_map<const tsl::protobuf::Descriptor*,
47+
absl::flat_hash_map<const google::protobuf::Descriptor*,
4748
CompilationEnvironments::ProcessNewEnvFn>*
4849
process_new_env_fns ABSL_GUARDED_BY(process_new_env_fns_mu) = nullptr;
4950

@@ -128,8 +129,8 @@ CompilationEnvironments::CreateFromProto(
128129
const CompilationEnvironmentsProto& proto) {
129130
auto envs = std::make_unique<CompilationEnvironments>();
130131

131-
const tsl::protobuf::DescriptorPool* const pool =
132-
tsl::protobuf::DescriptorPool::generated_pool();
132+
const google::protobuf::DescriptorPool* const pool =
133+
google::protobuf::DescriptorPool::generated_pool();
133134

134135
for (const auto& env_proto : proto.environments()) {
135136
std::string fullname;
@@ -140,22 +141,21 @@ CompilationEnvironments::CreateFromProto(
140141
env_proto.type_url()));
141142
}
142143

143-
const tsl::protobuf::Descriptor* const descriptor =
144+
const google::protobuf::Descriptor* const descriptor =
144145
pool->FindMessageTypeByName(fullname);
145146
if (descriptor == nullptr) {
146147
return absl::DataLossError(absl::StrCat(
147148
"Unknown CompilationEnvironment message type: ", fullname));
148149
}
149150

150-
const tsl::protobuf::Message* const prototype =
151-
tsl::protobuf::MessageFactory::generated_factory()->GetPrototype(
152-
descriptor);
151+
const google::protobuf::Message* const prototype =
152+
google::protobuf::MessageFactory::generated_factory()->GetPrototype(descriptor);
153153
if (prototype == nullptr) {
154154
return absl::InternalError(absl::StrCat(
155155
"Unsupported CompilationEnvironment message type: ", fullname));
156156
}
157157

158-
std::unique_ptr<tsl::protobuf::Message> env(prototype->New());
158+
std::unique_ptr<google::protobuf::Message> env(prototype->New());
159159
if (!env_proto.UnpackTo(env.get())) {
160160
return absl::DataLossError(absl::StrCat(
161161
"Unable to unpack CompilationEnvironment message of type '", fullname,
@@ -169,12 +169,11 @@ CompilationEnvironments::CreateFromProto(
169169
}
170170

171171
void CompilationEnvironments::RegisterProcessNewEnvFn(
172-
const tsl::protobuf::Descriptor* descriptor,
173-
ProcessNewEnvFn process_new_env) {
172+
const google::protobuf::Descriptor* descriptor, ProcessNewEnvFn process_new_env) {
174173
absl::MutexLock l(process_new_env_fns_mu);
175174
if (process_new_env_fns == nullptr) {
176175
process_new_env_fns =
177-
new absl::flat_hash_map<const tsl::protobuf::Descriptor*,
176+
new absl::flat_hash_map<const google::protobuf::Descriptor*,
178177
CompilationEnvironments::ProcessNewEnvFn>();
179178
}
180179
const bool inserted =
@@ -184,8 +183,21 @@ void CompilationEnvironments::RegisterProcessNewEnvFn(
184183
<< descriptor->full_name() << "' has already been registered";
185184
}
186185

186+
void CompilationEnvironments::DeregisterProcessNewEnvFn(
187+
const google::protobuf::Descriptor* descriptor) {
188+
absl::MutexLock l(process_new_env_fns_mu);
189+
if (process_new_env_fns == nullptr) {
190+
return;
191+
}
192+
const auto it = process_new_env_fns->find(descriptor);
193+
if (it == process_new_env_fns->end()) {
194+
return;
195+
}
196+
process_new_env_fns->erase(it);
197+
}
198+
187199
absl::Status CompilationEnvironments::InitializeAllKnownEnvs() {
188-
std::vector<const tsl::protobuf::Descriptor*> descriptors;
200+
std::vector<const google::protobuf::Descriptor*> descriptors;
189201
{
190202
absl::MutexLock l(process_new_env_fns_mu);
191203
if (process_new_env_fns == nullptr) {
@@ -208,25 +220,25 @@ absl::Status CompilationEnvironments::InitializeAllKnownEnvs() {
208220
}
209221

210222
absl::Status CompilationEnvironments::AddEnv(
211-
std::unique_ptr<tsl::protobuf::Message> env) {
223+
std::unique_ptr<google::protobuf::Message> env) {
212224
if (!env) {
213225
return absl::InvalidArgumentError(
214226
"Can not add a null compilation environment.");
215227
}
216-
const tsl::protobuf::Descriptor& descriptor = *env->GetDescriptor();
228+
const google::protobuf::Descriptor& descriptor = *env->GetDescriptor();
217229
return AddEnvImpl(descriptor, std::move(env));
218230
}
219231

220232
CompilationEnvironmentsProto CompilationEnvironments::ToProto() const {
221233
// Sort the environments by their message types' full names so that the
222234
// proto fields are deterministically ordered.
223-
std::vector<const tsl::protobuf::Descriptor*> descriptors;
235+
std::vector<const google::protobuf::Descriptor*> descriptors;
224236
descriptors.reserve(environments_.size());
225237
for (const auto& [descriptor, message] : environments_) {
226238
descriptors.push_back(descriptor);
227239
}
228-
absl::c_sort(descriptors, [](const tsl::protobuf::Descriptor* lhs,
229-
const tsl::protobuf::Descriptor* rhs) {
240+
absl::c_sort(descriptors, [](const google::protobuf::Descriptor* lhs,
241+
const google::protobuf::Descriptor* rhs) {
230242
return lhs->full_name() < rhs->full_name();
231243
});
232244

@@ -239,7 +251,7 @@ CompilationEnvironmentsProto CompilationEnvironments::ToProto() const {
239251

240252
CompilationEnvironments::ProcessNewEnvFn
241253
CompilationEnvironments::GetProcessNewEnvFn(
242-
const tsl::protobuf::Descriptor& descriptor) {
254+
const google::protobuf::Descriptor& descriptor) {
243255
absl::MutexLock l(process_new_env_fns_mu);
244256
if (process_new_env_fns == nullptr) {
245257
return nullptr;
@@ -262,8 +274,8 @@ void CompilationEnvironments::EnvAdded(absl::string_view env_type) {
262274
}
263275

264276
absl::Status CompilationEnvironments::AddEnvImpl(
265-
const tsl::protobuf::Descriptor& descriptor,
266-
std::unique_ptr<tsl::protobuf::Message> env) {
277+
const google::protobuf::Descriptor& descriptor,
278+
std::unique_ptr<google::protobuf::Message> env) {
267279
// Check if we already have an environment of env's type
268280
if (environments_.contains(&descriptor)) {
269281
return absl::AlreadyExistsError(absl::StrCat(
@@ -276,16 +288,16 @@ absl::Status CompilationEnvironments::AddEnvImpl(
276288
return absl::InvalidArgumentError(absl::StrCat(
277289
"Unknown CompilationEnvironment type ", descriptor.full_name()));
278290
}
279-
TF_ASSIGN_OR_RETURN(std::unique_ptr<tsl::protobuf::Message> processed_env,
291+
TF_ASSIGN_OR_RETURN(std::unique_ptr<google::protobuf::Message> processed_env,
280292
process_new_env(std::move(env)));
281293

282294
// Check for unknown fields
283-
const tsl::protobuf::UnknownFieldSet& unknown_fields =
295+
const google::protobuf::UnknownFieldSet& unknown_fields =
284296
processed_env->GetReflection()->GetUnknownFields(*processed_env);
285297
std::vector<int> unknown_tags;
286298
unknown_tags.reserve(unknown_fields.field_count());
287299
for (int i = 0; i < unknown_fields.field_count(); ++i) {
288-
const tsl::protobuf::UnknownField& field = unknown_fields.field(i);
300+
const google::protobuf::UnknownField& field = unknown_fields.field(i);
289301
unknown_tags.push_back(field.number());
290302
}
291303
if (!unknown_tags.empty()) {

xla/service/compilation_environments.h

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ limitations under the License.
2424
#include "absl/status/status.h"
2525
#include "absl/status/statusor.h"
2626
#include "absl/strings/string_view.h"
27+
#include "google/protobuf/descriptor.h"
2728
#include "xla/xla.pb.h"
28-
#include "tsl/platform/protobuf.h"
2929

3030
namespace xla {
3131

@@ -34,7 +34,7 @@ namespace xla {
3434
//
3535
// CompilationEnvironments uses lazy initialization, (see GetEnv() for more
3636
// details). Lazy initialization is used so we can avoid:
37-
// A) Requiring every code path to explitily construct all needed compilation
37+
// A) Requiring every code path to explicitly construct all needed compilation
3838
// environments, particularly when the default constructed environment is
3939
// all we need AND
4040
// B) Requiring CompilationEnvironments to implicitly construct all needed
@@ -45,8 +45,8 @@ namespace xla {
4545
class CompilationEnvironments {
4646
public:
4747
using ProcessNewEnvFn =
48-
std::function<absl::StatusOr<std::unique_ptr<tsl::protobuf::Message>>(
49-
std::unique_ptr<tsl::protobuf::Message>)>;
48+
std::function<absl::StatusOr<std::unique_ptr<google::protobuf::Message>>(
49+
std::unique_ptr<google::protobuf::Message>)>;
5050

5151
CompilationEnvironments() = default;
5252
CompilationEnvironments(const CompilationEnvironments& rhs) { *this = rhs; }
@@ -72,17 +72,22 @@ class CompilationEnvironments {
7272
//
7373
// REQUIRES:
7474
// - The output is *not* allowed to be null, even for null input.
75-
static void RegisterProcessNewEnvFn(
76-
const tsl::protobuf::Descriptor* descriptor,
77-
ProcessNewEnvFn process_new_env);
75+
// - `descriptor` must stay alive until the process ends, or until
76+
// `DeregisterProcessNewEnvFn` is called.
77+
static void RegisterProcessNewEnvFn(const google::protobuf::Descriptor* descriptor,
78+
ProcessNewEnvFn process_new_env);
79+
80+
// Deregisters the ProcessNewEnvFn for the given proto descriptor, if one
81+
// exists.
82+
static void DeregisterProcessNewEnvFn(const google::protobuf::Descriptor* descriptor);
7883

7984
// Adds env to the list of CompilationEnvironments. If an environment with
8085
// the same proto descriptor has already been added, returns an error.
8186
//
8287
// All added environments are processed via registered ProcessNewEnvFns. If
8388
// such a function was not regitered for env's proto descriptor or env's
8489
// proto type is unknown, an error will be returned.
85-
absl::Status AddEnv(std::unique_ptr<tsl::protobuf::Message> env);
90+
absl::Status AddEnv(std::unique_ptr<google::protobuf::Message> env);
8691

8792
// Returns the CompilationEnvironment corresponding to T. If such an
8893
// environment has not been added, ProcessNewEnvFn(nullptr) will be added and
@@ -116,7 +121,7 @@ class CompilationEnvironments {
116121
// Returns the ProcessNewEnvFn for the given env type. Returns nullptr if no
117122
// ProcessNewEnvFn has been registered for the env type.
118123
static ProcessNewEnvFn GetProcessNewEnvFn(
119-
const tsl::protobuf::Descriptor& descriptor);
124+
const google::protobuf::Descriptor& descriptor);
120125

121126
// Called by GetEnv(), when it lazily creates a new environment, to globally
122127
// track stats about how many such environments are created by
@@ -128,11 +133,11 @@ class CompilationEnvironments {
128133
// are added to CompilationEnvironments.
129134
static void EnvAdded(absl::string_view env_type);
130135

131-
absl::Status AddEnvImpl(const tsl::protobuf::Descriptor& descriptor,
132-
std::unique_ptr<tsl::protobuf::Message> env);
136+
absl::Status AddEnvImpl(const google::protobuf::Descriptor& descriptor,
137+
std::unique_ptr<google::protobuf::Message> env);
133138

134-
absl::flat_hash_map<const tsl::protobuf::Descriptor*,
135-
std::unique_ptr<tsl::protobuf::Message>>
139+
absl::flat_hash_map<const google::protobuf::Descriptor*,
140+
std::unique_ptr<google::protobuf::Message>>
136141
environments_;
137142
};
138143

@@ -158,7 +163,7 @@ T& CompilationEnvironments::GetMutableEnv() {
158163
it = environments_.find(descriptor);
159164
}
160165

161-
return tsl::protobuf::DownCastToGenerated<T>(*it->second);
166+
return google::protobuf::DownCastToGenerated<T>(*it->second);
162167
}
163168

164169
template <typename T>

0 commit comments

Comments
 (0)