Support arbitrary Python iterables in preprocessing layer adapt() methods #22156

@goyaladitya05

Description

Currently, most preprocessing layers (FeatureSpace, IndexLookup, TextVectorization, Discretization) accept only tf.data.Dataset objects in their adapt() methods. This creates friction for users on the JAX or PyTorch backends, who typically have data in other formats such as lists, generators, or custom iterables.

Current Limitation

# This works
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(data)
layer.adapt(dataset)

# This doesn't work
layer.adapt(["apple", "banana", "cherry"])  # Error

Proposed Solution

Enable preprocessing layers to accept arbitrary Python iterables (lists, generators, custom iterables) in their adapt() methods, similar to how Normalization already works.
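Internally, adapt() would need to consume any iterable in fixed-size batches before feeding it to the layer's accumulation logic. A minimal sketch of such a coercion helper, in plain Python (iter_batches is a hypothetical name, not an existing Keras API):

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def iter_batches(data: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive batches from any Python iterable.

    Hypothetical helper: adapt() could use something like this to
    consume lists, generators, and custom iterables uniformly,
    without materializing the whole input in memory.
    """
    it = iter(data)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch
```

Because the helper only holds one batch at a time, generators stay memory-efficient, and a tf.data.Dataset path can remain a separate fast path for backward compatibility.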

Expected Behavior

# Lists should work
layer.adapt(["apple", "banana", "cherry"])  

# Generators should work (memory efficient!)
def data_generator():
    for item in huge_dataset:
        yield item
layer.adapt(data_generator())  

# Custom iterables should work
layer.adapt(custom_iterable_object)  
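"Custom iterable" here means any object implementing the standard Python iterator protocol (__iter__). A minimal example of the kind of object adapt() would accept under this proposal (CorpusIterable is an illustrative class, not part of any library):

```python
class CorpusIterable:
    """Any object with __iter__ would qualify as adapt() input.

    Unlike a one-shot generator, this object can be iterated
    multiple times, which matters if adapt() needs several passes.
    """

    def __init__(self, items):
        self._items = list(items)

    def __iter__(self):
        # Return a fresh iterator on every call.
        return iter(self._items)
```

A PyTorch DataLoader satisfies the same protocol, which is why it would work without conversion.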

Benefits

  1. Backend Agnostic: JAX and PyTorch users can use native Python data structures without converting to TensorFlow datasets
  2. Memory Efficient: Generators allow processing large datasets without loading everything into memory
  3. Consistent API: All preprocessing layers would behave consistently with Normalization
  4. Backward Compatible: Existing code using tf.data.Dataset continues to work
  5. Pythonic: More natural for users coming from standard Python/NumPy workflows

Affected Layers

  • FeatureSpace
  • StringLookup / IntegerLookup (via IndexLookup)
  • TextVectorization
  • Discretization

Use Cases

1. Processing large text files with generators

from keras.layers import TextVectorization

def text_generator():
    with open("large_corpus.txt") as f:
        for line in f:
            yield line.strip()

vectorizer = TextVectorization()
vectorizer.adapt(text_generator(), batch_size=64)

2. JAX/PyTorch users with native data structures

# JAX users with lists
data = ["text1", "text2", "text3"]
layer.adapt(data)

# PyTorch users with custom iterables
layer.adapt(pytorch_dataloader)

3. Quick prototyping and testing

# Simple list for quick testing
lookup = StringLookup()
lookup.adapt(["cat", "dog", "bird"])

Related

This aligns with Keras 3's goal of being backend-agnostic and improving multi-backend compatibility. The idea came out of the discussion in //issues/18442.
