Segfault with prefetching and MLX arrays in key transform #47

@awni

Description

The code below segfaults on my machine (M1 Max, macOS 14.2).

Some observations:

  • Using NumPy arrays in place of MLX arrays in the key transform works fine
  • The segfault only occurs when prefetching is enabled

import mlx.core as mx
from mlx.data.datasets import load_cifar10

def get_cifar10(batch_size, root=None):
    tr = load_cifar10(root=root)

    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

    def normalize(x):
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    return tr_iter

if __name__ == "__main__":
    tr_iter = get_cifar10(256)
    for batch_counter, batch in enumerate(tr_iter):
        print(batch)
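
For comparison, the NumPy variant mentioned in the observations might look like the sketch below. This is an assumption about what "using NumPy in place of MLX" means here: `mean` and `std` are built as NumPy arrays, `normalize` keeps the same body, and the rest of the pipeline is unchanged.

```python
import numpy as np

# Assumed NumPy replacement for the MLX constants in the repro above.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))

def normalize(x):
    # x is expected to be an (H, W, 3) NumPy image array with values in [0, 255].
    x = x.astype("float32") / 255.0
    # Broadcasts the (1, 1, 3) constants over height and width.
    return (x - mean) / std
```

Passing this `normalize` to `.key_transform("image", normalize)` keeps the transform in NumPy, which is the configuration reported above as working.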
