@ChrisRackauckas suggests that this package provides many of the utilities needed to make broadcasting over specified axes efficient, as can be seen in DiffEqGPU.jl.
Can we discuss a user-facing API, so that we can compare directly against JAX's vmap?
For instance, suppose I have a function
f(x::Scalar, y::Vector, A::Array) = linalg...
How can I efficiently broadcast it over batches of inputs stored in containers with axes, such as multidimensional arrays ("tensors")?
# Broadcast over rows of second argument
vmap(f, in_axes=(nothing, 1, nothing))(scalar, array, array)
# Broadcast over axes for all arguments
vmap(f, in_axes=(1, 1, 3))(vector, array, tensor)
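To make the semantics I have in mind concrete, here is a rough reference sketch in plain Base Julia. Everything in it is an assumption for discussion purposes: `vmap`, `pick`, and the `in_axes` keyword are hypothetical names, not an existing API of this package, and the body of `f` is a stand-in because the original is elided. It loops over slices serially, so it only pins down what the result should be, not how to compute it efficiently.

```julia
using LinearAlgebra

# Hypothetical helper: take the i-th slice of `a` along axis `ax`,
# or pass `a` through unchanged when `ax` is `nothing`.
function pick(a, ax, i)
    ax === nothing && return a
    s = selectdim(a, ax, i)          # a view, so no copy is made
    return ndims(s) == 0 ? s[] : s   # unwrap 0-dimensional views into scalars
end

# Hypothetical `vmap`: returns a function that calls `f` once per slice.
function vmap(f; in_axes)
    return function (args...)
        length(args) == length(in_axes) ||
            throw(ArgumentError("need one in_axes entry per argument"))
        # Lengths of all mapped axes; they must agree.
        ns = [size(a, ax) for (a, ax) in zip(args, in_axes) if ax !== nothing]
        isempty(ns) && return f(args...)
        all(==(first(ns)), ns) ||
            throw(DimensionMismatch("mapped axes differ in length"))
        return [f(map((a, ax) -> pick(a, ax, i), args, in_axes)...) for i in 1:first(ns)]
    end
end

# Stand-in for the elided function body above.
f(x::Number, y::AbstractVector, A::AbstractMatrix) = x * dot(y, A * y)

Y = [1.0 2.0; 3.0 4.0]
A = [1.0 0.0; 0.0 1.0]
T = cat(A, 2 .* A; dims=3)

# Map over rows of the second argument only.
out_rows = vmap(f; in_axes=(nothing, 1, nothing))(2.0, Y, A)

# Map over an axis of every argument.
out_all = vmap(f; in_axes=(1, 1, 3))([1.0, 2.0], Y, T)
```

Note that axes are 1-based here, following Julia conventions, whereas JAX counts from 0. An efficient version would presumably fuse the per-slice calls rather than loop, which is where this package's utilities would come in.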
Further, is it possible to provide these as defaults for something like eachslice
so that broadcasting Just Works?
f.(scalar, eachrow(array), array)
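For comparison, this is roughly what is needed with plain broadcasting today, as far as I can tell. The function body and the data are made up for illustration. The catch is that an unmapped array has to be wrapped in `Ref`, otherwise broadcast iterates over its elements and `f` is called with a scalar in the matrix slot.

```julia
using LinearAlgebra

# Stand-in for the elided function body above.
f(x::Number, y::AbstractVector, A::AbstractMatrix) = x * dot(y, A * y)

Y = [1.0 2.0; 3.0 4.0]
A = [2.0 0.0; 0.0 2.0]

# Ref(A) makes broadcast treat A as a single value rather than a collection.
rows = f.(1.0, eachrow(Y), Ref(A))

# Mapping every argument: a vector, row slices, and slices along dimension 3.
T = cat(A, 3 .* A; dims=3)
all_mapped = f.([1.0, 2.0], eachrow(Y), eachslice(T; dims=3))
```

This gives the right answers, but each call to `f` is dispatched separately, so it says nothing about the efficiency question.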