
User-facing API like vmap #117

@jessebett


@ChrisRackauckas suggests that this package provides many of the utilities needed to make broadcasting over specified axes efficient, as can be seen in DiffEqGPU.jl.

Can we discuss a user-facing API, so that we can compare directly against JAX's `vmap`?

For instance, suppose I have a function

```julia
f(x::Number, y::Vector, A::Array) = linalg...
```

How can I efficiently broadcast `f` over batches of inputs stored along the axes of multidimensional arrays ("tensors")?

```julia
# Broadcast over the rows of the second argument; the first and third are passed whole
vmap(f, in_axes=(nothing, 1, nothing))(scalar, array, array)

# Broadcast over an axis of every argument (dimension 3 of `tensor`)
vmap(f, in_axes=(1, 1, 3))(vector, array, tensor)
```
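To pin down what the calls above should return, here is a minimal sketch of the proposed semantics in Base Julia. The `vmap` and `slices` functions are hypothetical (they are not an API of this package or of Base): unmapped arguments are wrapped in `Ref`, mapped ones are sliced with `eachslice`, and the result is an ordinary broadcast. The body `f(x, y, A) = x * sum(A * y)` is an assumed stand-in for the elided `linalg...`. The sketch says nothing about making this fast.

```julia
# Hypothetical `vmap` sketch: not an API of this package or of Base Julia.
# in_axes[i] === nothing -> argument i is passed whole to every call of f;
# in_axes[i] == d         -> argument i is sliced along dimension d.

slices(a, ::Nothing) = Ref(a)                 # unmapped: same value for every call
slices(a::AbstractVector, d::Integer) =       # a mapped vector yields its elements
    d == 1 ? a : throw(ArgumentError("a vector only has dimension 1"))
slices(a::AbstractArray, d::Integer) = eachslice(a; dims = d)  # one view per index along d

function vmap(f; in_axes)
    return function (args...)
        length(args) == length(in_axes) ||
            throw(ArgumentError("need exactly one in_axes entry per argument"))
        batched = map(slices, args, in_axes)  # tuple of Refs, vectors and slice collections
        return f.(batched...)                 # ordinary broadcast over the mapped axis
    end
end

# Hypothetical stand-in for the elided `linalg...` body.
f(x, y, A) = x * sum(A * y)

M = [1.0 2.0; 3.0 4.0]
T = cat([1.0 0.0; 0.0 1.0], [2.0 0.0; 0.0 2.0]; dims = 3)  # 2×2×2 array

rows_only = vmap(f, in_axes = (nothing, 1, nothing))(2.0, M, M)  # [32.0, 72.0]
all_axes  = vmap(f, in_axes = (1, 1, 3))([1.0, 10.0], M, T)      # [3.0, 140.0]
```

Julia dimensions are 1-based, so `in_axes = 1` here slices along the first dimension (the rows of a matrix), whereas JAX's 0-based `in_axes = 1` refers to the second axis.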

Further, could these be provided as defaults for something like `eachslice`, so that broadcasting Just Works?

```julia
f.(scalar, eachrow(array), Ref(array))  # Ref keeps `array` whole instead of broadcasting over its entries
```
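For comparison, a minimal sketch of what plain Base Julia broadcasting already does with `eachrow` and `eachslice` (no new API assumed). Wrapping an argument in `Ref` makes broadcast pass it whole to every call; without it, broadcast iterates over the array's entries and `f` silently receives individual numbers. The body `f(x, y, A) = x * sum(A * y)` is a hypothetical stand-in for the elided `linalg...`.

```julia
# Hypothetical stand-in for the elided `linalg...` body.
f(x, y, A) = x * sum(A * y)

M = [1.0 2.0; 3.0 4.0]
T = cat([1.0 0.0; 0.0 1.0], [2.0 0.0; 0.0 2.0]; dims = 3)  # 2×2×2 array

# Ref(M) is treated as a scalar by broadcast, so every call gets the whole matrix;
# eachrow(M) supplies one row per call.
rows_only = f.(2.0, eachrow(M), Ref(M))                         # [32.0, 72.0]

# Mapping every argument: elements of the vector, rows of M, and slices T[:, :, k].
all_axes = f.([1.0, 10.0], eachrow(M), eachslice(T; dims = 3))  # [3.0, 140.0]
```

This evaluates `f` once per slice; the gap relative to JAX's `vmap`, which transforms `f` into a single batched computation, is where the efficiency question above comes in.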
