Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 38 additions & 93 deletions keras/src/layers/preprocessing/index_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ def __init__(
# Remember original `vocabulary` as `input_vocabulary` for serialization
# via `get_config`. However, if `vocabulary` is a file path or a URL, we
# serialize the vocabulary as an asset and clear the original path/URL.
self.input_vocabulary = (
vocabulary if not isinstance(vocabulary, str) else None
)
self.input_vocabulary = vocabulary if not isinstance(vocabulary, str) else None
self.input_idf_weights = idf_weights

# We set this hidden attr to
Expand All @@ -231,9 +229,7 @@ def __init__(
# Masks should map to 0 for int output and be dropped otherwise. Max
# ints will be dropped from the bincount op.
mask_value = (
0
if self.output_mode == "int"
else tf.as_dtype(self._value_dtype).max
0 if self.output_mode == "int" else tf.as_dtype(self._value_dtype).max
)
if self.num_oov_indices == 0:
# If there are no OOV indices, we map OOV tokens to -1 and error
Expand All @@ -251,9 +247,7 @@ def __init__(
self._default_value = -1
if self.mask_token is not None:
self._mask_key = tf.convert_to_tensor(mask_key, self._key_dtype)
self._mask_value = tf.convert_to_tensor(
mask_value, self._value_dtype
)
self._mask_value = tf.convert_to_tensor(mask_value, self._value_dtype)

if self.output_mode == "tf_idf":
if self._has_input_vocabulary and idf_weights is None:
Expand Down Expand Up @@ -288,16 +282,12 @@ def __init__(
default_value=0,
)
if self.output_mode == "tf_idf":
self.token_document_counts = (
tf.lookup.experimental.MutableHashTable(
key_dtype=vocabulary_dtype,
value_dtype="int64",
default_value=0,
)
)
self.num_documents = tf.Variable(
0, dtype="int64", trainable=False
self.token_document_counts = tf.lookup.experimental.MutableHashTable(
key_dtype=vocabulary_dtype,
value_dtype="int64",
default_value=0,
)
self.num_documents = tf.Variable(0, dtype="int64", trainable=False)

def get_vocabulary(self, include_special_tokens=True):
"""Returns the current vocabulary of the layer.
Expand All @@ -322,18 +312,14 @@ def get_vocabulary(self, include_special_tokens=True):
self._tensor_vocab_to_numpy(vocab),
indices.numpy(),
)
lookup = collections.defaultdict(
lambda: self.oov_token, zip(indices, vocab)
)
lookup = collections.defaultdict(lambda: self.oov_token, zip(indices, vocab))
vocab = [lookup[x] for x in range(self.vocabulary_size())]
if self.mask_token is not None and self.output_mode == "int":
vocab[0] = self.mask_token
if not include_special_tokens:
vocab = vocab[self._token_start_index() :]
if self.vocabulary_dtype == "string":
return [
i.decode("utf-8") if isinstance(i, bytes) else i for i in vocab
]
return [i.decode("utf-8") if isinstance(i, bytes) else i for i in vocab]
else:
return vocab

Expand All @@ -345,10 +331,7 @@ def vocabulary_size(self):
indices.
"""
if tf.executing_eagerly():
return (
int(self.lookup_table.size().numpy())
+ self._token_start_index()
)
return int(self.lookup_table.size().numpy()) + self._token_start_index()
else:
return self.lookup_table.size() + self._token_start_index()

Expand Down Expand Up @@ -422,9 +405,7 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
)

if not tf.io.gfile.exists(vocabulary):
raise ValueError(
f"Vocabulary file {vocabulary} does not exist."
)
raise ValueError(f"Vocabulary file {vocabulary} does not exist.")
if self.output_mode == "tf_idf":
raise ValueError(
"output_mode `'tf_idf'` does not support loading a "
Expand Down Expand Up @@ -458,18 +439,15 @@ def set_vocabulary(self, vocabulary, idf_weights=None):

if vocabulary.size == 0:
raise ValueError(
"Cannot set an empty vocabulary. "
f"Received: vocabulary={vocabulary}"
"Cannot set an empty vocabulary. " f"Received: vocabulary={vocabulary}"
)

oov_start = self._oov_start_index()
token_start = self._token_start_index()
special_tokens = [self.mask_token] * oov_start + [
self.oov_token
] * self.num_oov_indices
found_special_tokens = np.array_equal(
special_tokens, vocabulary[:token_start]
)
found_special_tokens = np.array_equal(special_tokens, vocabulary[:token_start])
if found_special_tokens:
tokens = vocabulary[token_start:]
else:
Expand All @@ -496,11 +474,7 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
)
# Only error out for oov_token when invert=True. When invert=False,
# oov_token is unused during lookup.
if (
self.oov_token is not None
and self.invert
and self.oov_token in tokens
):
if self.oov_token is not None and self.invert and self.oov_token in tokens:
oov_index = np.argwhere(vocabulary == self.oov_token)[-1]
raise ValueError(
"Found reserved OOV token at unexpected location in "
Expand Down Expand Up @@ -551,9 +525,7 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
# zeros as well.
back_padding_value = 0
if self.pad_to_max_tokens and self.max_tokens is not None:
back_padding = (
self.max_tokens - front_padding - len(idf_weights)
)
back_padding = self.max_tokens - front_padding - len(idf_weights)
else:
back_padding = 0
weights = np.pad(
Expand Down Expand Up @@ -586,27 +558,24 @@ def variable_dtype(self):
def compute_output_shape(self, input_shape):
    """Return the static output shape produced for `input_shape`.

    Args:
        input_shape: Shape of the layer input, as any sequence of dims.

    Returns:
        `input_shape` itself (unconverted) when `output_mode` is `"int"`.
        Otherwise a tuple whose last dimension is the encoding depth
        reported by `_get_output_depth()`.
    """
    if self.output_mode == "int":
        return input_shape

    depth = self._get_output_depth()
    input_shape = tuple(input_shape)

    if self.output_mode == "one_hot":
        # One-hot encodes each element:
        # (batch, d1, ..., dN) -> (batch, d1, ..., dN, depth).
        # A trailing dimension of size 1 is replaced instead of kept.
        if len(input_shape) > 1 and input_shape[-1] == 1:
            return input_shape[:-1] + (depth,)
        return input_shape + (depth,)

    # Remaining modes (multi_hot, count, tf_idf): the last dimension is
    # treated as the sample dimension and replaced by `depth`.
    return input_shape[:-1] + (depth,)

def _get_output_depth(self):
    """Return the size of the encoded (last) output dimension.

    Returns:
        `self.max_tokens` when `pad_to_max_tokens` is truthy and
        `max_tokens` is not `None`; otherwise the value returned by
        `self.vocabulary_size()`.
    """
    # NOTE(review): `call` in this file still computes its depth from
    # `self._frozen_vocab_size` -- confirm both paths agree.
    if self.pad_to_max_tokens and self.max_tokens is not None:
        return self.max_tokens
    # NOTE(review): outside eager execution `vocabulary_size()` returns a
    # tensor expression rather than a Python int -- confirm callers only
    # reach this branch eagerly.
    return self.vocabulary_size()

def compute_output_spec(self, inputs):
    """Build the symbolic output spec for `inputs`.

    Args:
        inputs: Object exposing a `.shape` attribute.

    Returns:
        A `backend.KerasTensor` whose shape is
        `self.compute_output_shape(inputs.shape)` and whose dtype is
        `"int64"` in `"int"` output mode and `backend.floatx()` otherwise.
    """
    output_dtype = "int64" if self.output_mode == "int" else backend.floatx()
    output_shape = self.compute_output_shape(inputs.shape)
    return backend.KerasTensor(output_shape, dtype=output_dtype)

Expand Down Expand Up @@ -643,9 +612,7 @@ def update_state(self, data):
data = tf.expand_dims(data, 0)

tokens, counts = self._num_tokens(data)
self.token_counts.insert(
tokens, counts + self.token_counts.lookup(tokens)
)
self.token_counts.insert(tokens, counts + self.token_counts.lookup(tokens))

if self.output_mode == "tf_idf":
# Dedupe each row of our dataset.
Expand All @@ -663,9 +630,7 @@ def update_state(self, data):
if isinstance(data, tf.RaggedTensor):
self.num_documents.assign_add(data.nrows())
else:
self.num_documents.assign_add(
tf.shape(data, out_type="int64")[0]
)
self.num_documents.assign_add(tf.shape(data, out_type="int64")[0])

def finalize_state(self):
if self._has_input_vocabulary or tf.equal(self.token_counts.size(), 0):
Expand Down Expand Up @@ -738,9 +703,7 @@ def reset_state(self):

self.token_counts.remove(self.token_counts.export()[0])
if self.output_mode == "tf_idf":
self.token_document_counts.remove(
self.token_document_counts.export()[0]
)
self.token_document_counts.remove(self.token_document_counts.export()[0])
self.num_documents.assign(0)

def call(self, inputs):
Expand Down Expand Up @@ -771,19 +734,11 @@ def call(self, inputs):
lookups = tf.squeeze(lookups, -1)
return lookups

depth = (
self.max_tokens
if self.pad_to_max_tokens
else self._frozen_vocab_size
)
idf_weights = (
self.idf_weights_const if self.output_mode == "tf_idf" else None
)
depth = self.max_tokens if self.pad_to_max_tokens else self._frozen_vocab_size
idf_weights = self.idf_weights_const if self.output_mode == "tf_idf" else None
output = numerical_utils.encode_categorical_inputs(
lookups,
output_mode=(
"count" if self.output_mode == "tf_idf" else self.output_mode
),
output_mode=("count" if self.output_mode == "tf_idf" else self.output_mode),
depth=depth,
dtype=self._value_dtype,
sparse=self.sparse,
Expand Down Expand Up @@ -900,22 +855,16 @@ def load_assets(self, dir_path):

def _uninitialized_lookup_table(self):
    """Create a `tf.lookup.StaticHashTable` from the null initializer.

    The initializer is obtained from `get_null_initializer` for this
    layer's key/value dtypes, and the table falls back to
    `self._default_value`.

    Returns:
        The newly constructed `tf.lookup.StaticHashTable`.
    """
    # Both the initializer and the table are created under
    # `tf.init_scope()`.
    with tf.init_scope():
        initializer = get_null_initializer(self._key_dtype, self._value_dtype)
        return tf.lookup.StaticHashTable(initializer, self._default_value)

def _lookup_table_from_tokens(self, tokens):
with tf.init_scope():
token_start = self._token_start_index()
token_end = token_start + tf.size(tokens)
indices_dtype = (
self._key_dtype if self.invert else self._value_dtype
)
indices_dtype = self._key_dtype if self.invert else self._value_dtype
indices = tf.range(token_start, token_end, dtype=indices_dtype)
keys, values = (
(indices, tokens) if self.invert else (tokens, indices)
)
keys, values = (indices, tokens) if self.invert else (tokens, indices)
initializer = tf.lookup.KeyValueTensorInitializer(
keys, values, self._key_dtype, self._value_dtype
)
Expand Down Expand Up @@ -949,11 +898,7 @@ def _expand_dims(self, inputs, axis):
return tf.expand_dims(inputs, axis)

def _oov_start_index(self):
    """Return the first index that follows the optional mask slot.

    Returns:
        `1` when a mask token is configured (`mask_token is not None`) and
        `output_mode` is `"int"`; `0` in every other case.
    """
    reserves_mask_slot = (
        self.mask_token is not None and self.output_mode == "int"
    )
    return 1 if reserves_mask_slot else 0

def _token_start_index(self):
    """Return the index at which regular vocabulary tokens begin.

    Returns:
        The sum of `self._oov_start_index()` and `self.num_oov_indices`.
    """
    return self._oov_start_index() + self.num_oov_indices
Expand Down
Loading
Loading