@@ -28,22 +28,23 @@ limitations under the License.
2828#include " absl/container/flat_hash_map.h"
2929#include " absl/log/check.h"
3030#include " absl/log/log.h"
31- #include " absl/memory/memory.h"
3231#include " absl/status/status.h"
3332#include " absl/status/statusor.h"
3433#include " absl/strings/str_cat.h"
3534#include " absl/strings/str_join.h"
3635#include " absl/strings/string_view.h"
3736#include " absl/synchronization/mutex.h"
37+ #include " google/protobuf/descriptor.h"
38+ #include " google/protobuf/message.h"
39+ #include " google/protobuf/unknown_field_set.h"
3840#include " xla/tsl/platform/errors.h"
3941#include " xla/tsl/platform/statusor.h"
40- #include " tsl/platform/protobuf.h"
4142
4243namespace xla {
4344namespace {
4445
4546ABSL_CONST_INIT absl::Mutex process_new_env_fns_mu (absl::kConstInit );
46- absl::flat_hash_map<const tsl ::protobuf::Descriptor*,
47+ absl::flat_hash_map<const google ::protobuf::Descriptor*,
4748 CompilationEnvironments::ProcessNewEnvFn>*
4849 process_new_env_fns ABSL_GUARDED_BY (process_new_env_fns_mu) = nullptr;
4950
@@ -128,8 +129,8 @@ CompilationEnvironments::CreateFromProto(
128129 const CompilationEnvironmentsProto& proto) {
129130 auto envs = std::make_unique<CompilationEnvironments>();
130131
131- const tsl ::protobuf::DescriptorPool* const pool =
132- tsl ::protobuf::DescriptorPool::generated_pool ();
132+ const google ::protobuf::DescriptorPool* const pool =
133+ google ::protobuf::DescriptorPool::generated_pool ();
133134
134135 for (const auto & env_proto : proto.environments ()) {
135136 std::string fullname;
@@ -140,22 +141,21 @@ CompilationEnvironments::CreateFromProto(
140141 env_proto.type_url ()));
141142 }
142143
143- const tsl ::protobuf::Descriptor* const descriptor =
144+ const google ::protobuf::Descriptor* const descriptor =
144145 pool->FindMessageTypeByName (fullname);
145146 if (descriptor == nullptr ) {
146147 return absl::DataLossError (absl::StrCat (
147148 " Unknown CompilationEnvironment message type: " , fullname));
148149 }
149150
150- const tsl::protobuf::Message* const prototype =
151- tsl::protobuf::MessageFactory::generated_factory ()->GetPrototype (
152- descriptor);
151+ const google::protobuf::Message* const prototype =
152+ google::protobuf::MessageFactory::generated_factory ()->GetPrototype (descriptor);
153153 if (prototype == nullptr ) {
154154 return absl::InternalError (absl::StrCat (
155155 " Unsupported CompilationEnvironment message type: " , fullname));
156156 }
157157
158- std::unique_ptr<tsl ::protobuf::Message> env (prototype->New ());
158+ std::unique_ptr<google ::protobuf::Message> env (prototype->New ());
159159 if (!env_proto.UnpackTo (env.get ())) {
160160 return absl::DataLossError (absl::StrCat (
161161 " Unable to unpack CompilationEnvironment message of type '" , fullname,
@@ -169,12 +169,11 @@ CompilationEnvironments::CreateFromProto(
169169}
170170
171171void CompilationEnvironments::RegisterProcessNewEnvFn (
172- const tsl::protobuf::Descriptor* descriptor,
173- ProcessNewEnvFn process_new_env) {
172+ const google::protobuf::Descriptor* descriptor, ProcessNewEnvFn process_new_env) {
174173 absl::MutexLock l (process_new_env_fns_mu);
175174 if (process_new_env_fns == nullptr ) {
176175 process_new_env_fns =
177- new absl::flat_hash_map<const tsl ::protobuf::Descriptor*,
176+ new absl::flat_hash_map<const google ::protobuf::Descriptor*,
178177 CompilationEnvironments::ProcessNewEnvFn>();
179178 }
180179 const bool inserted =
@@ -184,8 +183,21 @@ void CompilationEnvironments::RegisterProcessNewEnvFn(
184183 << descriptor->full_name () << " ' has already been registered" ;
185184}
186185
186+ void CompilationEnvironments::DeregisterProcessNewEnvFn (
187+ const google::protobuf::Descriptor* descriptor) {
188+ absl::MutexLock l (process_new_env_fns_mu);
189+ if (process_new_env_fns == nullptr ) {
190+ return ;
191+ }
192+ const auto it = process_new_env_fns->find (descriptor);
193+ if (it == process_new_env_fns->end ()) {
194+ return ;
195+ }
196+ process_new_env_fns->erase (it);
197+ }
198+
187199absl::Status CompilationEnvironments::InitializeAllKnownEnvs () {
188- std::vector<const tsl ::protobuf::Descriptor*> descriptors;
200+ std::vector<const google ::protobuf::Descriptor*> descriptors;
189201 {
190202 absl::MutexLock l (process_new_env_fns_mu);
191203 if (process_new_env_fns == nullptr ) {
@@ -208,25 +220,25 @@ absl::Status CompilationEnvironments::InitializeAllKnownEnvs() {
208220}
209221
210222absl::Status CompilationEnvironments::AddEnv (
211- std::unique_ptr<tsl ::protobuf::Message> env) {
223+ std::unique_ptr<google ::protobuf::Message> env) {
212224 if (!env) {
213225 return absl::InvalidArgumentError (
214226 " Can not add a null compilation environment." );
215227 }
216- const tsl ::protobuf::Descriptor& descriptor = *env->GetDescriptor ();
228+ const google ::protobuf::Descriptor& descriptor = *env->GetDescriptor ();
217229 return AddEnvImpl (descriptor, std::move (env));
218230}
219231
220232CompilationEnvironmentsProto CompilationEnvironments::ToProto () const {
221233 // Sort the environments by their message types' full names so that the
222234 // proto fields are deterministically ordered.
223- std::vector<const tsl ::protobuf::Descriptor*> descriptors;
235+ std::vector<const google ::protobuf::Descriptor*> descriptors;
224236 descriptors.reserve (environments_.size ());
225237 for (const auto & [descriptor, message] : environments_) {
226238 descriptors.push_back (descriptor);
227239 }
228- absl::c_sort (descriptors, [](const tsl ::protobuf::Descriptor* lhs,
229- const tsl ::protobuf::Descriptor* rhs) {
240+ absl::c_sort (descriptors, [](const google ::protobuf::Descriptor* lhs,
241+ const google ::protobuf::Descriptor* rhs) {
230242 return lhs->full_name () < rhs->full_name ();
231243 });
232244
@@ -239,7 +251,7 @@ CompilationEnvironmentsProto CompilationEnvironments::ToProto() const {
239251
240252CompilationEnvironments::ProcessNewEnvFn
241253CompilationEnvironments::GetProcessNewEnvFn (
242- const tsl ::protobuf::Descriptor& descriptor) {
254+ const google ::protobuf::Descriptor& descriptor) {
243255 absl::MutexLock l (process_new_env_fns_mu);
244256 if (process_new_env_fns == nullptr ) {
245257 return nullptr ;
@@ -262,8 +274,8 @@ void CompilationEnvironments::EnvAdded(absl::string_view env_type) {
262274}
263275
264276absl::Status CompilationEnvironments::AddEnvImpl (
265- const tsl ::protobuf::Descriptor& descriptor,
266- std::unique_ptr<tsl ::protobuf::Message> env) {
277+ const google ::protobuf::Descriptor& descriptor,
278+ std::unique_ptr<google ::protobuf::Message> env) {
267279 // Check if we already have an environment of env's type
268280 if (environments_.contains (&descriptor)) {
269281 return absl::AlreadyExistsError (absl::StrCat (
@@ -276,16 +288,16 @@ absl::Status CompilationEnvironments::AddEnvImpl(
276288 return absl::InvalidArgumentError (absl::StrCat (
277289 " Unknown CompilationEnvironment type " , descriptor.full_name ()));
278290 }
279- TF_ASSIGN_OR_RETURN (std::unique_ptr<tsl ::protobuf::Message> processed_env,
291+ TF_ASSIGN_OR_RETURN (std::unique_ptr<google ::protobuf::Message> processed_env,
280292 process_new_env (std::move (env)));
281293
282294 // Check for unknown fields
283- const tsl ::protobuf::UnknownFieldSet& unknown_fields =
295+ const google ::protobuf::UnknownFieldSet& unknown_fields =
284296 processed_env->GetReflection ()->GetUnknownFields (*processed_env);
285297 std::vector<int > unknown_tags;
286298 unknown_tags.reserve (unknown_fields.field_count ());
287299 for (int i = 0 ; i < unknown_fields.field_count (); ++i) {
288- const tsl ::protobuf::UnknownField& field = unknown_fields.field (i);
300+ const google ::protobuf::UnknownField& field = unknown_fields.field (i);
289301 unknown_tags.push_back (field.number ());
290302 }
291303 if (!unknown_tags.empty ()) {
0 commit comments