@@ -25,8 +25,8 @@ limitations under the License.
2525#include " tensorflow/core/framework/tensor_shape.h"
2626#include " tensorflow/core/kernels/task_runner.h"
2727#include " tensorflow/core/kernels/unique_ali_op_util.h"
28- #include " tensorflow/core/lib/hash/hash.h"
2928#include " tensorflow/core/lib/core/status.h"
29+ #include " tensorflow/core/lib/hash/hash.h"
3030#include " tensorflow/core/util/env_var.h"
3131
3232namespace tensorflow {
@@ -41,40 +41,43 @@ const char* kStlHashMapString = "STL";
4141const char * kAbslHashMapString = " ABSL" ;
4242const char * kGoogleHashMapString = " GOOGLE" ;
4343const int64 kDefaultUniqueRatioHint = 4 ;
44- }
44+ } // namespace
4545
4646template <typename T, typename TIndex>
4747class UniqueAliOp : public OpKernel {
4848 public:
4949 explicit UniqueAliOp (OpKernelConstruction* context) : OpKernel(context) {
50- OP_REQUIRES_OK (context, ReadInt64FromEnvVar (kUniqueOpPartitionSizeEnv ,
51- kPartitionSize , &partition_size_));
52- OP_REQUIRES (context, partition_size_ > 0 ,
53- errors::InvalidArgument (" Invaild PARTITION_SIZE=" ,
54- partition_size_));
50+ OP_REQUIRES_OK (
51+ context, ReadInt64FromEnvVar (kUniqueOpPartitionSizeEnv , kPartitionSize ,
52+ &partition_size_));
53+ OP_REQUIRES (
54+ context, partition_size_ > 0 ,
55+ errors::InvalidArgument (" Invaild PARTITION_SIZE=" , partition_size_));
5556
56- OP_REQUIRES_OK (context, ReadBoolFromEnvVar ( kUniqueOpSerialEnv ,
57- false , &serial_));
57+ OP_REQUIRES_OK (context,
58+ ReadBoolFromEnvVar ( kUniqueOpSerialEnv , false , &serial_));
5859
5960 // NOTE(zycao): Hash map insertion and lookup performance is dominating in
6061 // Unique Op. Based on benchmark results, 'google::dense_hash_map' will be
6162 // used as default for most key types except string.
6263 //
63- // By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a particular
64- // hash map could be selected to use. Possible choices are listed below:
64+ // By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a
65+ // particular hash map could be selected to use. Possible choices are listed
66+ // below:
6567 // "MULTIMAP" for multimap parrallel process,
6668 // "STL" for std::unordred_map,
6769 // "ABSL" for absl::flat_hash_map,
6870 // "GOOGLE" for google::dense_hash_map.
6971 std::string hash_map_str;
70- OP_REQUIRES_OK (context, ReadStringFromEnvVar ( kUniqueOpHashMapEnv ,
71- kGoogleHashMapString ,
72- &hash_map_str));
72+ OP_REQUIRES_OK (
73+ context, ReadStringFromEnvVar ( kUniqueOpHashMapEnv , kGoogleHashMapString ,
74+ &hash_map_str));
7375 std::transform (hash_map_str.begin (), hash_map_str.end (),
7476 hash_map_str.begin (), ::toupper);
7577
7678 OP_REQUIRES_OK (context, ReadInt64FromEnvVar (kUniqueOpUniqRatioHint ,
77- kDefaultUniqueRatioHint , &unique_ratio_hint_));
79+ kDefaultUniqueRatioHint ,
80+ &unique_ratio_hint_));
7881 OP_REQUIRES (context, unique_ratio_hint_ > 0 ,
7982 errors::InvalidArgument (" Invaild " , kUniqueOpUniqRatioHint , " =" ,
8083 unique_ratio_hint_));
@@ -83,7 +86,8 @@ class UniqueAliOp : public OpKernel {
8386 map_flag_ = MULTIMAP;
8487 static char print_once = [] {
8588 LOG (INFO) << " MultiMapCompute preserved "
86- " dense hash map key: " << kPreseverdEmptyKey ;
89+ " dense hash map key: "
90+ << kPreseverdEmptyKey ;
8791 return ' \0 ' ;
8892 }();
8993 } else if (!hash_map_str.compare (kStlHashMapString )) {
@@ -95,7 +99,6 @@ class UniqueAliOp : public OpKernel {
9599 } else {
96100 map_flag_ = GOOGLE;
97101 }
98-
99102 }
100103
101104 void Compute (OpKernelContext* context) override {
@@ -110,16 +113,14 @@ class UniqueAliOp : public OpKernel {
110113 Tensor output;
111114 Tensor output_counter;
112115 if (context->num_inputs () == 1 ) {
113- UniqueWithoutAxis<T, TIndex>(context, input,
114- &idx, &output, &output_counter, num_outputs (),
115- partition_size_, serial_, unique_ratio_hint_,
116- map_flag_);
116+ UniqueWithoutAxis<T, TIndex>(
117+ context, input, &idx, &output, &output_counter, num_outputs (),
118+ partition_size_, serial_, unique_ratio_hint_, map_flag_);
117119 } else {
118120 const Tensor& axis_tensor = context->input (1 );
119- UniqueWithAxis<T, TIndex>(context, input,
120- axis_tensor, &idx, &output, &output_counter,
121- num_outputs (), partition_size_, serial_,
122- unique_ratio_hint_, map_flag_);
121+ UniqueWithAxis<T, TIndex>(context, input, axis_tensor, &idx, &output,
122+ &output_counter, num_outputs (), partition_size_,
123+ serial_, unique_ratio_hint_, map_flag_);
123124 }
124125 context->set_output (0 , output);
125126 context->set_output (1 , idx);
@@ -128,33 +129,65 @@ class UniqueAliOp : public OpKernel {
128129 }
129130 }
130131
132+ protected:
131133 bool serial_ = false ;
132134 int64 partition_size_ = 0 ;
133135 int64 unique_ratio_hint_;
134136 UniqueMaps map_flag_ = GOOGLE; // "GOOGLE" dense hash map is default
135137};
136138
// Kernel derived from UniqueAliOp<T, TIndex> that overrides Compute() to call
// UniqueWithExtraCounts, forwarding the integer "N" attribute (num_sparse_)
// together with the configuration members inherited from the base class.
// The registration macros in this file bind it to the "UniqueWithExtraCounts"
// op for int32 and int64 "out_idx".
139+ template <typename T, typename TIndex>
140+ class UniqueWithCountAliOp : public UniqueAliOp <T, TIndex> {
// The base class depends on the template parameters, so its members are not
// found by unqualified lookup; these using-declarations make them (and
// OpKernel::num_outputs) usable by plain name below.
141+ using UniqueAliOp<T, TIndex>::serial_;
142+ using UniqueAliOp<T, TIndex>::partition_size_;
143+ using UniqueAliOp<T, TIndex>::unique_ratio_hint_;
144+ using UniqueAliOp<T, TIndex>::map_flag_;
145+ using OpKernel::num_outputs;
146+
147+ public:
// Runs the UniqueAliOp constructor first (which reads its settings from
// environment variables), then reads the "N" attribute into num_sparse_.
// A failed GetAttr is reported through OP_REQUIRES_OK.
148+ explicit UniqueWithCountAliOp (OpKernelConstruction* context)
149+ : UniqueAliOp<T, TIndex>(context) {
150+ OP_REQUIRES_OK (context, context->GetAttr (" N" , &num_sparse_));
151+ }
152+
// Takes input 0, hands it to UniqueWithExtraCounts along with the three
// result tensors, and publishes them as outputs 0 (output), 1 (idx) and
// 2 (output_counter).
// NOTE(review): no status check is visible between the helper call and the
// set_output calls; confirm how UniqueWithExtraCounts reports failure and
// whether the tensors are always initialized on return.
153+ void Compute (OpKernelContext* context) override {
154+ const Tensor& input = context->input (0 );
155+ Tensor idx;
156+ Tensor output;
157+ Tensor output_counter;
158+ UniqueWithExtraCounts<T, TIndex>(
159+ context, input, &idx, &output, &output_counter, num_outputs (),
160+ partition_size_, serial_, unique_ratio_hint_, num_sparse_, map_flag_);
161+ context->set_output (0 , output);
162+ context->set_output (1 , idx);
163+ context->set_output (2 , output_counter);
164+ }
165+
166+ private:
// Value of the "N" attribute; has no default initializer and is assigned
// only by GetAttr in the constructor.
167+ int num_sparse_;
168+ };
169+
137170#define REGISTER_UNIQUE (type ) \
138171 REGISTER_KERNEL_BUILDER (Name(" Unique" ) \
139172 .Device(DEVICE_CPU) \
140173 .TypeConstraint<type>(" T" ) \
141174 .TypeConstraint<int32>(" out_idx" ), \
142- UniqueAliOp<type, int32>); \
175+ UniqueAliOp<type, int32>) \
143176 REGISTER_KERNEL_BUILDER (Name(" Unique" ) \
144177 .Device(DEVICE_CPU) \
145178 .TypeConstraint<type>(" T" ) \
146179 .TypeConstraint<int64>(" out_idx" ), \
147- UniqueAliOp<type, int64>); \
180+ UniqueAliOp<type, int64>) \
148181 REGISTER_KERNEL_BUILDER (Name(" UniqueV2" ) \
149182 .Device(DEVICE_CPU) \
150183 .TypeConstraint<type>(" T" ) \
151184 .TypeConstraint<int32>(" out_idx" ), \
152- UniqueAliOp<type, int32>); \
185+ UniqueAliOp<type, int32>) \
153186 REGISTER_KERNEL_BUILDER (Name(" UniqueV2" ) \
154187 .Device(DEVICE_CPU) \
155188 .TypeConstraint<type>(" T" ) \
156189 .TypeConstraint<int64>(" out_idx" ), \
157- UniqueAliOp<type, int64>); \
190+ UniqueAliOp<type, int64>) \
158191 REGISTER_KERNEL_BUILDER (Name(" UniqueWithCounts" ) \
159192 .Device(DEVICE_CPU) \
160193 .TypeConstraint<type>(" T" ) \
@@ -164,7 +197,7 @@ class UniqueAliOp : public OpKernel {
164197 .Device(DEVICE_CPU) \
165198 .TypeConstraint<type>(" T" ) \
166199 .TypeConstraint<int64>(" out_idx" ), \
167- UniqueAliOp<type, int64>); \
200+ UniqueAliOp<type, int64>) \
168201 REGISTER_KERNEL_BUILDER (Name(" UniqueWithCountsV2" ) \
169202 .Device(DEVICE_CPU) \
170203 .TypeConstraint<type>(" T" ) \
@@ -174,7 +207,17 @@ class UniqueAliOp : public OpKernel {
174207 .Device(DEVICE_CPU) \
175208 .TypeConstraint<type>(" T" ) \
176209 .TypeConstraint<int64>(" out_idx" ), \
177- UniqueAliOp<type, int64>)
210+ UniqueAliOp<type, int64>) \
211+ REGISTER_KERNEL_BUILDER (Name(" UniqueWithExtraCounts" ) \
212+ .Device(DEVICE_CPU) \
213+ .TypeConstraint<type>(" T" ) \
214+ .TypeConstraint<int32>(" out_idx" ), \
215+ UniqueWithCountAliOp<type, int32>) \
216+ REGISTER_KERNEL_BUILDER (Name(" UniqueWithExtraCounts" ) \
217+ .Device(DEVICE_CPU) \
218+ .TypeConstraint<type>(" T" ) \
219+ .TypeConstraint<int64>(" out_idx" ), \
220+ UniqueWithCountAliOp<type, int64>)
178221TF_CALL_REAL_NUMBER_TYPES (REGISTER_UNIQUE);
179222REGISTER_UNIQUE (string)
180223#undef REGISTER_UNIQUE
@@ -198,12 +241,22 @@ REGISTER_UNIQUE(string)
198241 .HostMemory(" count" ) \
199242 .TypeConstraint<type>(" T" ) \
200243 .TypeConstraint<int64>(" out_idx" ), \
201- UniqueAliOp<type, int64>);
244+ UniqueAliOp<type, int64>) \
245+ REGISTER_KERNEL_BUILDER(Name(" UniqueWithExtraCounts" ) \
246+ .Device(DEVICE_GPU) \
247+ .TypeConstraint<type>(" T" ) \
248+ .TypeConstraint<int32>(" out_idx" ), \
249+ UniqueWithCountAliOp<type, int32>) \
250+ REGISTER_KERNEL_BUILDER(Name(" UniqueWithExtraCounts" ) \
251+ .Device(DEVICE_GPU) \
252+ .TypeConstraint<type>(" T" ) \
253+ .TypeConstraint<int64>(" out_idx" ), \
254+ UniqueWithCountAliOp<type, int64>);
202255TF_CALL_REAL_NUMBER_TYPES (REGISTER_UNIQUE);
203256REGISTER_UNIQUE (string)
204257#undef REGISTER_UNIQUE
205- #endif // GOOGLE_CUDA
206-
258+ #endif // GOOGLE_CUDA
259+
207260#ifdef TENSORFLOW_USE_SYCL
208261REGISTER_KERNEL_BUILDER (Name(" Unique" )
209262 .Device(DEVICE_SYCL)
0 commit comments