Skip to content

Commit 2b15e8a

Browse files
authored
[Embedding] Fix shared embedding frequency counting problem. (#962)
Signed-off-by: 泊霆 <[email protected]> Co-authored-by: 泊霆 <[email protected]>
1 parent d84837f commit 2b15e8a

File tree

12 files changed

+386
-72
lines changed

12 files changed

+386
-72
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "UniqueWithExtraCounts"
3+
visibility: HIDDEN
4+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
op {
2+
graph_op_name: "UniqueWithExtraCounts"
3+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "UniqueWithExtraCounts"
3+
visibility: HIDDEN
4+
}

tensorflow/core/kernels/unique_ali_op.cc

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ limitations under the License.
2525
#include "tensorflow/core/framework/tensor_shape.h"
2626
#include "tensorflow/core/kernels/task_runner.h"
2727
#include "tensorflow/core/kernels/unique_ali_op_util.h"
28-
#include "tensorflow/core/lib/hash/hash.h"
2928
#include "tensorflow/core/lib/core/status.h"
29+
#include "tensorflow/core/lib/hash/hash.h"
3030
#include "tensorflow/core/util/env_var.h"
3131

3232
namespace tensorflow {
@@ -41,40 +41,43 @@ const char* kStlHashMapString = "STL";
4141
const char* kAbslHashMapString = "ABSL";
4242
const char* kGoogleHashMapString = "GOOGLE";
4343
const int64 kDefaultUniqueRatioHint = 4;
44-
}
44+
} // namespace
4545

4646
template <typename T, typename TIndex>
4747
class UniqueAliOp : public OpKernel {
4848
public:
4949
explicit UniqueAliOp(OpKernelConstruction* context) : OpKernel(context) {
50-
OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv,
51-
kPartitionSize, &partition_size_));
52-
OP_REQUIRES(context, partition_size_ > 0,
53-
errors::InvalidArgument("Invaild PARTITION_SIZE=",
54-
partition_size_));
50+
OP_REQUIRES_OK(
51+
context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv, kPartitionSize,
52+
&partition_size_));
53+
OP_REQUIRES(
54+
context, partition_size_ > 0,
55+
errors::InvalidArgument("Invaild PARTITION_SIZE=", partition_size_));
5556

56-
OP_REQUIRES_OK(context, ReadBoolFromEnvVar(kUniqueOpSerialEnv,
57-
false, &serial_));
57+
OP_REQUIRES_OK(context,
58+
ReadBoolFromEnvVar(kUniqueOpSerialEnv, false, &serial_));
5859

5960
// NOTE(zycao>: Hash map insertion and lookup performance is dominating in
6061
// Unique Op. Based on benchmark results, 'google::dense_hash_map' will be
6162
// used as default for most key types except string.
6263
//
63-
// By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a particular
64-
// hash map could be seleteed to use. Possible choices are listed below:
64+
// By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a
65+
// particular hash map could be seleteed to use. Possible choices are listed
66+
// below:
6567
// "MULTIMAP" for multimap parrallel process,
6668
// "STL" for std::unordred_map,
6769
// "ABSL" for absl::flat_hash_map,
6870
// "GOOGLE" for google::dense_hash_map.
6971
std::string hash_map_str;
70-
OP_REQUIRES_OK(context, ReadStringFromEnvVar(kUniqueOpHashMapEnv,
71-
kGoogleHashMapString,
72-
&hash_map_str));
72+
OP_REQUIRES_OK(
73+
context, ReadStringFromEnvVar(kUniqueOpHashMapEnv, kGoogleHashMapString,
74+
&hash_map_str));
7375
std::transform(hash_map_str.begin(), hash_map_str.end(),
7476
hash_map_str.begin(), ::toupper);
7577

7678
OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpUniqRatioHint,
77-
kDefaultUniqueRatioHint, &unique_ratio_hint_));
79+
kDefaultUniqueRatioHint,
80+
&unique_ratio_hint_));
7881
OP_REQUIRES(context, unique_ratio_hint_ > 0,
7982
errors::InvalidArgument("Invaild ", kUniqueOpUniqRatioHint, "=",
8083
unique_ratio_hint_));
@@ -83,7 +86,8 @@ class UniqueAliOp : public OpKernel {
8386
map_flag_ = MULTIMAP;
8487
static char print_once = [] {
8588
LOG(INFO) << "MultiMapCompute preserved "
86-
"dense hash map key: " << kPreseverdEmptyKey;
89+
"dense hash map key: "
90+
<< kPreseverdEmptyKey;
8791
return '\0';
8892
}();
8993
} else if (!hash_map_str.compare(kStlHashMapString)) {
@@ -95,7 +99,6 @@ class UniqueAliOp : public OpKernel {
9599
} else {
96100
map_flag_ = GOOGLE;
97101
}
98-
99102
}
100103

101104
void Compute(OpKernelContext* context) override {
@@ -110,16 +113,14 @@ class UniqueAliOp : public OpKernel {
110113
Tensor output;
111114
Tensor output_counter;
112115
if (context->num_inputs() == 1) {
113-
UniqueWithoutAxis<T, TIndex>(context, input,
114-
&idx, &output, &output_counter, num_outputs(),
115-
partition_size_, serial_, unique_ratio_hint_,
116-
map_flag_);
116+
UniqueWithoutAxis<T, TIndex>(
117+
context, input, &idx, &output, &output_counter, num_outputs(),
118+
partition_size_, serial_, unique_ratio_hint_, map_flag_);
117119
} else {
118120
const Tensor& axis_tensor = context->input(1);
119-
UniqueWithAxis<T, TIndex>(context, input,
120-
axis_tensor, &idx, &output, &output_counter,
121-
num_outputs(), partition_size_, serial_,
122-
unique_ratio_hint_, map_flag_);
121+
UniqueWithAxis<T, TIndex>(context, input, axis_tensor, &idx, &output,
122+
&output_counter, num_outputs(), partition_size_,
123+
serial_, unique_ratio_hint_, map_flag_);
123124
}
124125
context->set_output(0, output);
125126
context->set_output(1, idx);
@@ -128,33 +129,65 @@ class UniqueAliOp : public OpKernel {
128129
}
129130
}
130131

132+
protected:
131133
bool serial_ = false;
132134
int64 partition_size_ = 0;
133135
int64 unique_ratio_hint_;
134136
UniqueMaps map_flag_ = GOOGLE; // "GOOGLE" dense hash map is default
135137
};
136138

139+
template <typename T, typename TIndex>
140+
class UniqueWithCountAliOp : public UniqueAliOp<T, TIndex> {
141+
using UniqueAliOp<T, TIndex>::serial_;
142+
using UniqueAliOp<T, TIndex>::partition_size_;
143+
using UniqueAliOp<T, TIndex>::unique_ratio_hint_;
144+
using UniqueAliOp<T, TIndex>::map_flag_;
145+
using OpKernel::num_outputs;
146+
147+
public:
148+
explicit UniqueWithCountAliOp(OpKernelConstruction* context)
149+
: UniqueAliOp<T, TIndex>(context) {
150+
OP_REQUIRES_OK(context, context->GetAttr("N", &num_sparse_));
151+
}
152+
153+
void Compute(OpKernelContext* context) override {
154+
const Tensor& input = context->input(0);
155+
Tensor idx;
156+
Tensor output;
157+
Tensor output_counter;
158+
UniqueWithExtraCounts<T, TIndex>(
159+
context, input, &idx, &output, &output_counter, num_outputs(),
160+
partition_size_, serial_, unique_ratio_hint_, num_sparse_, map_flag_);
161+
context->set_output(0, output);
162+
context->set_output(1, idx);
163+
context->set_output(2, output_counter);
164+
}
165+
166+
private:
167+
int num_sparse_;
168+
};
169+
137170
#define REGISTER_UNIQUE(type) \
138171
REGISTER_KERNEL_BUILDER(Name("Unique") \
139172
.Device(DEVICE_CPU) \
140173
.TypeConstraint<type>("T") \
141174
.TypeConstraint<int32>("out_idx"), \
142-
UniqueAliOp<type, int32>); \
175+
UniqueAliOp<type, int32>) \
143176
REGISTER_KERNEL_BUILDER(Name("Unique") \
144177
.Device(DEVICE_CPU) \
145178
.TypeConstraint<type>("T") \
146179
.TypeConstraint<int64>("out_idx"), \
147-
UniqueAliOp<type, int64>); \
180+
UniqueAliOp<type, int64>) \
148181
REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
149182
.Device(DEVICE_CPU) \
150183
.TypeConstraint<type>("T") \
151184
.TypeConstraint<int32>("out_idx"), \
152-
UniqueAliOp<type, int32>); \
185+
UniqueAliOp<type, int32>) \
153186
REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
154187
.Device(DEVICE_CPU) \
155188
.TypeConstraint<type>("T") \
156189
.TypeConstraint<int64>("out_idx"), \
157-
UniqueAliOp<type, int64>); \
190+
UniqueAliOp<type, int64>) \
158191
REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
159192
.Device(DEVICE_CPU) \
160193
.TypeConstraint<type>("T") \
@@ -164,7 +197,7 @@ class UniqueAliOp : public OpKernel {
164197
.Device(DEVICE_CPU) \
165198
.TypeConstraint<type>("T") \
166199
.TypeConstraint<int64>("out_idx"), \
167-
UniqueAliOp<type, int64>); \
200+
UniqueAliOp<type, int64>) \
168201
REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \
169202
.Device(DEVICE_CPU) \
170203
.TypeConstraint<type>("T") \
@@ -174,7 +207,17 @@ class UniqueAliOp : public OpKernel {
174207
.Device(DEVICE_CPU) \
175208
.TypeConstraint<type>("T") \
176209
.TypeConstraint<int64>("out_idx"), \
177-
UniqueAliOp<type, int64>)
210+
UniqueAliOp<type, int64>) \
211+
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
212+
.Device(DEVICE_CPU) \
213+
.TypeConstraint<type>("T") \
214+
.TypeConstraint<int32>("out_idx"), \
215+
UniqueWithCountAliOp<type, int32>) \
216+
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
217+
.Device(DEVICE_CPU) \
218+
.TypeConstraint<type>("T") \
219+
.TypeConstraint<int64>("out_idx"), \
220+
UniqueWithCountAliOp<type, int64>)
178221
TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
179222
REGISTER_UNIQUE(string)
180223
#undef REGISTER_UNIQUE
@@ -198,12 +241,22 @@ REGISTER_UNIQUE(string)
198241
.HostMemory("count") \
199242
.TypeConstraint<type>("T") \
200243
.TypeConstraint<int64>("out_idx"), \
201-
UniqueAliOp<type, int64>);
244+
UniqueAliOp<type, int64>) \
245+
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
246+
.Device(DEVICE_GPU) \
247+
.TypeConstraint<type>("T") \
248+
.TypeConstraint<int32>("out_idx"), \
249+
UniqueWithCountAliOp<type, int32>) \
250+
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
251+
.Device(DEVICE_GPU) \
252+
.TypeConstraint<type>("T") \
253+
.TypeConstraint<int64>("out_idx"), \
254+
UniqueWithCountAliOp<type, int64>);
202255
TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
203256
REGISTER_UNIQUE(string)
204257
#undef REGISTER_UNIQUE
205-
#endif //GOOGLE_CUDA
206-
258+
#endif // GOOGLE_CUDA
259+
207260
#ifdef TENSORFLOW_USE_SYCL
208261
REGISTER_KERNEL_BUILDER(Name("Unique")
209262
.Device(DEVICE_SYCL)

0 commit comments

Comments
 (0)