Commit d0bb822

Support image and audio information in task summaries (#1819)

1 parent 6d19bb3 commit d0bb822

5 files changed: +54 -15 lines changed

keras_nlp/src/layers/preprocessing/audio_converter.py

Lines changed: 4 additions & 0 deletions

@@ -49,6 +49,10 @@ class AudioConverter(PreprocessingLayer):

     backbone_cls = None

+    def audio_shape(self):
+        """Returns the preprocessed size of a single audio sample."""
+        return (None,)
+
     @classproperty
     def presets(cls):
         """List built-in presets for an `AudioConverter` subclass."""

keras_nlp/src/layers/preprocessing/image_converter.py

Lines changed: 4 additions & 0 deletions

@@ -52,6 +52,10 @@ class ImageConverter(PreprocessingLayer):

     backbone_cls = None

+    def image_size(self):
+        """Returns the default size of a single image."""
+        return (None, None)
+
     @classproperty
     def presets(cls):
         """List built-in presets for an `ImageConverter` subclass."""

keras_nlp/src/layers/preprocessing/resizing_image_converter.py

Lines changed: 6 additions & 2 deletions

@@ -73,13 +73,17 @@ def __init__(
         # By default, we just do a simple resize. Any model can subclass this
         # layer for preprocessing of a raw image to a model image input.
         self.resizing = keras.layers.Resizing(
-            height,
-            width,
+            height=height,
+            width=width,
             crop_to_aspect_ratio=crop_to_aspect_ratio,
             interpolation=interpolation,
             data_format=data_format,
         )

+    def image_size(self):
+        """Returns the preprocessed size of a single image."""
+        return (self.resizing.height, self.resizing.width)
+
     @preprocessing_function
     def call(self, inputs):
         return self.resizing(inputs)
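
Because image_size() just reads the height and width off the wrapped keras.layers.Resizing layer, it automatically reflects whatever the converter was configured with. A usage sketch, assuming the class is exported as keras_nlp.layers.ResizingImageConverter:

from keras_nlp.layers import ResizingImageConverter  # assumed export path

# The converter delegates to keras.layers.Resizing, so image_size() simply
# reports the (height, width) the wrapped layer was built with.
converter = ResizingImageConverter(height=512, width=384)
print(converter.image_size())  # (512, 384)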

keras_nlp/src/models/task.py

Lines changed: 36 additions & 13 deletions

@@ -293,14 +293,20 @@ def summary(
             print_fn = print_msg

         def highlight_number(x):
-            return f"[color(45)]{x}[/]" if x is None else f"[color(34)]{x}[/]"
+            if x is None:
+                return f"[color(45)]{x}[/]"
+            return f"[color(34)]{x:,}[/]"  # Format number with commas.

         def highlight_symbol(x):
             return f"[color(33)]{x}[/]"

         def bold_text(x):
             return f"[bold]{x}[/]"

+        def highlight_shape(shape):
+            highlighted = [highlight_number(x) for x in shape]
+            return "(" + ", ".join(highlighted) + ")"
+
         if self.preprocessor:
             # Create a rich console for printing. Capture for non-interactive logging.
             if print_fn:
@@ -312,27 +318,44 @@ def bold_text(x):
             console = rich_console.Console(highlight=False)

             column_1 = rich_table.Column(
-                "Tokenizer (type)",
+                "Layer (type)",
                 justify="left",
-                width=int(0.5 * line_length),
+                width=int(0.6 * line_length),
             )
             column_2 = rich_table.Column(
-                "Vocab #",
+                "Config",
                 justify="right",
-                width=int(0.5 * line_length),
+                width=int(0.4 * line_length),
             )
             table = rich_table.Table(
                 column_1, column_2, width=line_length, show_lines=True
             )
+
+            def add_layer(layer, info):
+                layer_name = markup.escape(layer.name)
+                layer_class = highlight_symbol(
+                    markup.escape(layer.__class__.__name__)
+                )
+                table.add_row(
+                    f"{layer_name} ({layer_class})",
+                    info,
+                )
+
             tokenizer = self.preprocessor.tokenizer
-            tokenizer_name = markup.escape(tokenizer.name)
-            tokenizer_class = highlight_symbol(
-                markup.escape(tokenizer.__class__.__name__)
-            )
-            table.add_row(
-                f"{tokenizer_name} ({tokenizer_class})",
-                highlight_number(f"{tokenizer.vocabulary_size():,}"),
-            )
+            if tokenizer:
+                info = "Vocab size: "
+                info += highlight_number(tokenizer.vocabulary_size())
+                add_layer(tokenizer, info)
+            image_converter = self.preprocessor.image_converter
+            if image_converter:
+                info = "Image size: "
+                info += highlight_shape(image_converter.image_size())
+                add_layer(image_converter, info)
+            audio_converter = self.preprocessor.audio_converter
+            if audio_converter:
+                info = "Audio shape: "
+                info += highlight_shape(audio_converter.audio_shape())
+                add_layer(audio_converter, info)

             # Print the to the console.
             preprocessor_name = markup.escape(self.preprocessor.name)
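
To see how the new helpers fit together, here is a small standalone sketch that mirrors the highlight_number / highlight_shape logic above outside of Task.summary(); the rich markup strings come from the diff, the example values are made up. Note that the None branch must return early: falling through to the comma format would raise a TypeError for None, which matters because the default converter shapes contain None.

def highlight_number(x):
    if x is None:
        # Return early: f"{None:,}" below would raise a TypeError.
        return f"[color(45)]{x}[/]"
    return f"[color(34)]{x:,}[/]"  # Format number with commas.


def highlight_shape(shape):
    highlighted = [highlight_number(x) for x in shape]
    return "(" + ", ".join(highlighted) + ")"


print(highlight_number(256000))       # [color(34)]256,000[/]
print(highlight_shape((224, 224)))    # ([color(34)]224[/], [color(34)]224[/])
print(highlight_shape((None, None)))  # ([color(45)]None[/], [color(45)]None[/])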

keras_nlp/src/models/whisper/whisper_audio_converter.py

Lines changed: 4 additions & 0 deletions

@@ -95,6 +95,10 @@ def __init__(
         # `(num_fft_bins // 2 + 1, num_mels).`
         self.mel_filters = self._get_mel_filters()

+    def audio_shape(self):
+        """Returns the preprocessed size of a single audio sample."""
+        return (self.max_audio_length, self.num_mels)
+
     def _get_mel_filters(self):
        """
        Adapted from Hugging Face
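
With Whisper's defaults (80 mel bins, a 30 second maximum audio length), the override reports (max_audio_length, num_mels). A usage sketch, assuming the class is exported as keras_nlp.models.WhisperAudioConverter and that the constructor defaults match those documented for the layer:

from keras_nlp.models import WhisperAudioConverter  # assumed export path

# Instantiated with its documented defaults (assumed: num_mels=80,
# max_audio_length=30), so the new hook reports (30, 80).
converter = WhisperAudioConverter()
print(converter.audio_shape())  # (max_audio_length, num_mels), e.g. (30, 80)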
