TypeVar compatibility with vmap or scan
#16004
joeryjoery asked this question in Q&A
There is no generally applicable annotation in this case, because …

```python
from typing import TypeVar

import jax

T = TypeVar("T")

def sample_fun(x: T) -> T:
    return x

def batch_fun(xs: T) -> T:
    return jax.vmap(sample_fun)(xs)
```

Of course, if you know that the input is an `Array`, you can annotate it directly:

```python
import jax
from jax import Array

def sample_fun(x: Array) -> Array:
    return x

def batch_fun(xs: Array) -> Array:
    return jax.vmap(sample_fun)(xs)
```

From the static typing perspective, there is no way to express that an array is …
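A minimal runtime sketch of the `T -> T` version above, assuming a recent JAX release in which `jax.Array` exists. `scan_fun` is a hypothetical helper added here to cover the `scan` half of the question; it is not from the thread. The point it illustrates: both `jax.vmap` and `jax.lax.scan` hand back the same pytree structure they received (here a dict of arrays), which is all that `T -> T` promises, while the leading batch axis appears only in the array shapes, where a static type checker cannot see it.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def sample_fun(x: T) -> T:
    # Operates on one unbatched example; identity for illustration.
    return x


def batch_fun(xs: T) -> T:
    # vmap maps sample_fun over axis 0 of every leaf of xs.
    return jax.vmap(sample_fun)(xs)


def scan_fun(xs: T) -> T:
    # Hypothetical helper: lax.scan also slices axis 0 of every leaf and
    # stacks the per-step outputs back along a new axis 0.
    def step(carry, x):
        return carry, sample_fun(x)

    _, ys = jax.lax.scan(step, None, xs)
    return ys


# A pytree of arrays whose leaves share a leading batch axis of size 4.
xs = {"a": jnp.ones((4, 3)), "b": jnp.arange(4.0)}
ys = batch_fun(xs)
zs = scan_fun(xs)

# The container type (a dict of arrays) is unchanged; the batch axis
# is visible only in the shapes.
print(ys["a"].shape, zs["b"].shape)  # (4, 3) (4,)
```

Because the type checker only tracks the container type, annotating `xs` and the result with the same `T` is accurate but says nothing about the extra axis.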
Hi, I'm wondering whether there is a way to type-annotate a `TypeVar` with a batch dimension, something that would make it compatible with `vmap`, `scan`, or the like.

To give an example, I'm looking for something in place of `Batched`. I have tried to do this with, e.g., `typing.Sequence`, but my type checker complains that `Array` is not a `Sequence`.

Any hints would be greatly appreciated :)
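One possible workaround, not suggested in the thread and offered only as a sketch under assumptions: a hypothetical `Batched` alias built on `typing.Annotated`. Per PEP 593, type checkers treat `Annotated[X, ...]` as plain `X`, so `Batched[Array]` is still an `Array` and the "`Array` is not a `Sequence`" complaint should not arise. The metadata string only documents the batch axis; it is not checked statically or at runtime, which is consistent with the answer's point that static typing cannot express this. `Batched` is not a JAX API, and whether a given checker accepts a generic `Annotated` alias is assumed rather than verified here.

```python
from typing import Annotated, TypeVar, get_args, get_type_hints

import jax
import jax.numpy as jnp
from jax import Array

T = TypeVar("T")

# Hypothetical alias, not part of JAX: type checkers see Batched[X] as X.
# The string is documentation only; nothing verifies a batch axis exists.
Batched = Annotated[T, "leading batch axis"]


def sample_fun(x: Array) -> Array:
    # Operates on a single example.
    return x * 2


def batch_fun(xs: Batched[Array]) -> Batched[Array]:
    # vmap supplies the batch axis that the alias merely documents.
    return jax.vmap(sample_fun)(xs)


out = batch_fun(jnp.ones((5, 3)))
print(out.shape)  # (5, 3)

# The marker is kept at runtime and can be read back if desired.
hint = get_type_hints(batch_fun, include_extras=True)["xs"]
print(get_args(hint)[1])  # leading batch axis
```

The design choice is deliberately weak: the alias gives readers a name for "this has an extra leading axis" without asking the type checker to enforce something it cannot represent.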