
Commit 0e3c331

[Embedding] Make GroupEmbedding compatible with sequence feature_column interface. (#852)
Signed-off-by: JunqiHu <[email protected]>
1 parent ce39947 · commit 0e3c331

12 files changed: +438 −304 lines changed
tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py

Lines changed: 122 additions & 97 deletions
This hunk rewrites `sequence_input_layer`. Besides a black-style reformat (quote and line-wrapping churn, which is why GitHub shows the whole function as deleted and re-added), it teaches the layer to defer columns that carry a `group_name` and resolve them through one fused group-embedding lookup. The substantive changes, with the reformatted text shown as context:

````diff
@@ -32,108 +32,133 @@
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.feature_column import group_embedding_column as gec

 # pylint: disable=protected-access


 def sequence_input_layer(
     features, feature_columns, weight_collections=None, trainable=True):
   """Builds input layer for sequence input.

   All `feature_columns` must be sequence dense columns with the same
   `sequence_length`. The output of this method can be fed into sequence
   networks, such as RNN.

   The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
   `T` is the maximum sequence length for this batch, which could differ from
   batch to batch.

   If multiple `feature_columns` are given with `Di` `num_elements` each, their
   outputs are concatenated. So, the final `Tensor` has shape
   `[batch_size, T, D0 + D1 + ... + Dn]`.

   Example:

   ```python
   rating = sequence_numeric_column('rating')
   watches = sequence_categorical_column_with_identity(
       'watches', num_buckets=1000)
   watches_embedding = embedding_column(watches, dimension=10)
   columns = [rating, watches]

   features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
   input_layer, sequence_length = sequence_input_layer(features, columns)

   rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
   outputs, state = tf.compat.v1.nn.dynamic_rnn(
       rnn_cell, inputs=input_layer, sequence_length=sequence_length)
   ```

   Args:
     features: A dict mapping keys to tensors.
     feature_columns: An iterable of dense sequence columns. Valid columns are
       - `embedding_column` that wraps a `sequence_categorical_column_with_*`
       - `sequence_numeric_column`.
     weight_collections: A list of collection names to which the Variable will be
       added. Note that variables will also be added to collections
       `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
     trainable: If `True` also add the variable to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES`.

   Returns:
     An `(input_layer, sequence_length)` tuple where:
     - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
       `T` is the maximum sequence length for this batch, which could differ
       from batch to batch. `D` is the sum of `num_elements` for all
       `feature_columns`.
     - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
       length for each example.

   Raises:
     ValueError: If any of the `feature_columns` is the wrong type.
   """
   feature_columns = fc._normalize_feature_columns(feature_columns)
   for c in feature_columns:
     if not isinstance(c, fc._SequenceDenseColumn):
       raise ValueError(
           "All feature_columns must be of type _SequenceDenseColumn. "
           "You can wrap a sequence_categorical_column with an embedding_column "
           "or indicator_column. "
           "Given (type {}): {}".format(type(c), c)
       )

   with variable_scope.variable_scope(
       None, default_name="sequence_input_layer", values=features.values()
   ):
     builder = fc._LazyBuilder(features)
     output_tensors = []
     sequence_lengths = []
     ordered_columns = []
+    group_name_set = set()
+    group_embedding_list = []
+    embedding_columns = []

-    for column in sorted(feature_columns, key=lambda x: x.name):
+    for index, column in enumerate(sorted(feature_columns, key=lambda x: x.name)):
+      group_name = getattr(column, "group_name", "")
       ordered_columns.append(column)
       with variable_scope.variable_scope(
           None, default_name=column._var_scope_name
       ):
-        dense_tensor, sequence_length = column._get_sequence_dense_tensor(
-            builder,
-            weight_collections=weight_collections,
-            trainable=trainable)
-        # Flattens the final dimension to produce a 3D Tensor.
-        num_elements = column._variable_shape.num_elements()
-        shape = array_ops.shape(dense_tensor)
-        target_shape = [shape[0], shape[1], num_elements]
-        output_tensors.append(
-            array_ops.reshape(dense_tensor, shape=target_shape))
-        sequence_lengths.append(sequence_length)
+        if group_name != "":
+          # Group-embedding column: defer the lookup so that all columns
+          # sharing a group can be resolved in one fused pass below.
+          # Append placeholders to keep the output order stable.
+          group_name_set.add(group_name)
+          output_tensors.append(None)
+          group_embedding_list.append(index)
+          embedding_columns.append(column)
+          sequence_lengths.append(None)
+        else:
+          dense_tensor, sequence_length = column._get_sequence_dense_tensor(
+              builder,
+              weight_collections=weight_collections,
+              trainable=trainable,
+          )
+          # Flattens the final dimension to produce a 3D Tensor.
+          num_elements = column._variable_shape.num_elements()
+          shape = array_ops.shape(dense_tensor)
+          target_shape = [shape[0], shape[1], num_elements]
+          output_tensors.append(
+              array_ops.reshape(dense_tensor, shape=target_shape)
+          )
+          sequence_lengths.append(sequence_length)

+    # Resolve all deferred group-embedding columns at once, then fill the
+    # placeholders back in at their original positions.
+    group_embedding_tensor = gec._get_global_group_embedding_scope(
+        group_name_set, builder, weight_collections, trainable
+    )
+    for ind, column in zip(group_embedding_list, embedding_columns):
+      output_tensor, sequence_length = group_embedding_tensor[column]
+      output_tensors[ind] = output_tensor
+      sequence_lengths[ind] = sequence_length

     fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
     fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
     sequence_length = _assert_all_equal_and_return(sequence_lengths)

     concat_result = array_ops.concat(output_tensors, -1)
     ops.add_to_collection(
         ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, concat_result
     )
     return concat_result, sequence_length


 def concatenate_context_input(context_input, sequence_input):
````
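For context, a minimal sketch of how a group-named sequence embedding column might flow through the updated layer. This is not from the commit: the `group_embedding_column_scope` helper name follows DeepRec's group-embedding examples but is an assumption here, and `serialized_examples` is a stand-in input.

```python
# Hedged sketch, not part of this commit. Assumes a DeepRec-style
# group_embedding_column_scope helper that stamps a `group_name` on the
# embedding columns created inside it; sequence_input_layer (above) then
# routes those columns through the fused group lookup instead of calling
# _get_sequence_dense_tensor per column.
import tensorflow as tf
from tensorflow.contrib.feature_column.python.feature_column import (
    sequence_feature_column as sfc)

watches = sfc.sequence_categorical_column_with_identity(
    'watches', num_buckets=1000)

# Assumption: columns built under one scope share one fused lookup.
with tf.feature_column.group_embedding_column_scope(name='watch_group'):
  watches_embedding = tf.feature_column.embedding_column(
      watches, dimension=10)

serialized_examples = tf.compat.v1.placeholder(tf.string, shape=[None])
features = tf.io.parse_example(
    serialized_examples,
    features=tf.feature_column.make_parse_example_spec([watches]))

# Non-group columns still take the per-column path; group columns come
# back from gec._get_global_group_embedding_scope after the loop.
input_layer, sequence_length = sfc.sequence_input_layer(
    features, [watches_embedding])
```

The placeholder-then-fill dance in the diff exists so that the final concatenation order still matches the sorted column order even though group lookups happen after the loop.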

tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_base_ops.cu.h

Lines changed: 9 additions & 7 deletions
```diff
@@ -49,13 +49,13 @@ __global__ void SetToIntMaxSTG128(const int batch_size, int* values_offset) {
   }
 }

-__global__ void CalcPerElementRowOffset(const int batch_size, const int64_t nnz,
-                                        const int64_t* indices,
+__global__ void CalcPerElementRowOffset(int batch_size, int nnz,
+                                        int stride, const int64_t* indices,
                                         volatile int* values_offset) {
   const int thread_offset = blockIdx.x * blockDim.x + threadIdx.x;
   const int int_max = 0x7fffffff;
-  if (thread_offset < int(nnz)) {
-    const int64_t element_row = indices[thread_offset];
+  if (thread_offset < nnz) {
+    const int64_t element_row = indices[stride * thread_offset];
     atomicMin((int*)values_offset + int(element_row), thread_offset);
     __syncthreads();
     if (thread_offset < int(batch_size - 1)) {
@@ -67,7 +67,7 @@ __global__ void CalcPerElementRowOffset(const int batch_size, const int64_t nnz,
   }
 }

-inline void launch_cal_per_element_row_offset(const int batch_size, int nnz,
+inline void launch_cal_per_element_row_offset(const int batch_size, int nnz, int stride,
                                               const int64_t* sp_indices,
                                               int* offset_indices,
                                               cudaStream_t stream) {
@@ -78,7 +78,7 @@ inline void launch_cal_per_element_row_offset(const int batch_size, int nnz,

   blocks = (nnz - 1) / threads + 1;
   CalcPerElementRowOffset<<<blocks, threads, 0, stream>>>(
-      batch_size, nnz, sp_indices, offset_indices);
+      batch_size, nnz, stride, sp_indices, offset_indices);
 }

 template <typename TKey, typename TValue, Combiner combiner, int Tilesize>
```
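The new `stride` argument is what makes this kernel reusable for sequence features: SparseTensor indices are flattened int64 coordinate tuples, so entry `i`'s batch row lives at `indices[stride * i]` — stride 2 for rank-2 `[row, col]` indices, stride 3 for rank-3 `[batch, time, col]` sequence indices. Below is a simplified, self-contained sketch of the same atomicMin pattern; it is not the production kernel (a 0x7f-byte sentinel stands in for the SetToIntMaxSTG128 initializer, and there is no fix-up pass for empty trailing rows).

```cuda
// Hedged standalone sketch of CalcPerElementRowOffset's core idea.
#include <cstdio>
#include <cuda_runtime.h>

// Each sparse entry i owns `stride` coordinates; coordinate 0 is the
// batch row. atomicMin records the smallest flat entry index per row,
// i.e. where that row's values start.
__global__ void RowOffsetSketch(int nnz, int stride,
                                const long long* indices, int* row_offset) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nnz) {
    const int row = static_cast<int>(indices[stride * i]);
    atomicMin(row_offset + row, i);
  }
}

int main() {
  // Two examples, rank-3 sequence indices [batch, time, col] => stride 3.
  const long long h_indices[] = {0, 0, 0,  0, 1, 0,  1, 0, 0,  1, 2, 0};
  const int nnz = 4, stride = 3, batch = 2;

  long long* d_indices;
  int* d_offset;
  cudaMalloc(&d_indices, sizeof(h_indices));
  cudaMalloc(&d_offset, batch * sizeof(int));
  cudaMemcpy(d_indices, h_indices, sizeof(h_indices),
             cudaMemcpyHostToDevice);
  cudaMemset(d_offset, 0x7f, batch * sizeof(int));  // ~INT_MAX sentinel

  RowOffsetSketch<<<1, 64>>>(nnz, stride, d_indices, d_offset);

  int h_offset[2];
  cudaMemcpy(h_offset, d_offset, sizeof(h_offset), cudaMemcpyDeviceToHost);
  printf("row offsets: %d %d\n", h_offset[0], h_offset[1]);  // prints: 0 2
  cudaFree(d_indices);
  cudaFree(d_offset);
  return 0;
}
```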
```diff
@@ -569,6 +569,7 @@ class GroupEmbeddingLookupForwardBaseOp : public OpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("dimension", &dimension_));
     OP_REQUIRES_OK(c, c->GetAttr("max_norm", &max_norm_));
     OP_REQUIRES_OK(c, c->GetAttr("ignore_weights", &ignore_weights_));
+    OP_REQUIRES_OK(c, c->GetAttr("is_sequence", &is_sequence_));
     lookuper_.initialize(num_lookups_, dimension_, max_norm_);
   }

@@ -677,10 +678,11 @@ class GroupEmbeddingLookupForwardBaseOp : public OpKernel {
   int num_lookups_;
   int dimension_;
   bool ignore_weights_;
+  bool is_sequence_;
 };

 }  // namespace

 }  // namespace tensorflow

-#endif // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA
```
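Both this GPU base op and the CPU base op (next file) now read an `is_sequence` attr in their constructors, so the matching op registrations must declare it. A hedged sketch of what such a declaration plausibly looks like — the op name, inputs, and outputs here are illustrative, and the real REGISTER_OP change is in one of the commit's other nine files, not in this excerpt:

```cpp
// Illustrative only: op name and signature are assumptions, not this
// commit's actual registration.
#include "tensorflow/core/framework/op.h"

REGISTER_OP("GroupEmbeddingLookupSketch")
    .Input("sp_values: int64")
    .Input("sp_indices: int64")
    .Output("emb_vectors: float")
    .Attr("num_lookups: int >= 1")
    .Attr("dimension: int")
    .Attr("max_norm: float")
    .Attr("ignore_weights: bool")
    // A default makes the attr optional in serialized NodeDefs, so graphs
    // exported before this commit still load; GetAttr then returns false.
    .Attr("is_sequence: bool = false");
```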

tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_base_ops.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -27,7 +27,7 @@ class GroupLookupBaseCpuOp : public OpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("dimension", &m_dimension));
     // OP_REQUIRES_OK(c, c->GetAttr("max_norm", &max_norm_));
     OP_REQUIRES_OK(c, c->GetAttr("ignore_weights", &m_ignore_weights));
-
+    OP_REQUIRES_OK(c, c->GetAttr("is_sequence", &m_is_sequence));
     OP_REQUIRES_OK(c, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv,
                                           kPartitionSize, &partition_size_));
     OP_REQUIRES(
@@ -52,6 +52,7 @@ class GroupLookupBaseCpuOp : public OpKernel {
   int m_dimension;
   bool m_is_use_default_value_tensor;
   bool m_ignore_weights;
+  bool m_is_sequence;
   std::string m_combiner;
   bool serial_ = false;
   int64 partition_size_ = 0;
```