
Yes. If you pass out_shape=(jax.ShapeDtypeStruct(...), jax.ShapeDtypeStruct(...)), your kernel receives two output references, one per entry in out_shape and in the same order, after the input references:

def kernel(..., out_ref1, out_ref2):
  # input refs come first, then one output ref per entry in out_shape
  ...
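A minimal runnable sketch, assuming the question is about `jax.experimental.pallas.pallas_call` (the thread does not name the function). The names `kernel` and `two_outputs` are hypothetical, and `interpret=True` runs the kernel in interpreter mode so the example also works on CPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(x_ref, doubled_ref, squared_ref):
    # One input ref, then one output ref per entry in out_shape (same order).
    x = x_ref[...]
    doubled_ref[...] = x + x
    squared_ref[...] = x * x


def two_outputs(x):
    return pl.pallas_call(
        kernel,
        # A tuple of two ShapeDtypeStructs -> two output refs in the kernel.
        out_shape=(
            jax.ShapeDtypeStruct(x.shape, x.dtype),
            jax.ShapeDtypeStruct(x.shape, x.dtype),
        ),
        interpret=True,  # interpreter mode, so no GPU/TPU is required
    )(x)


x = jnp.arange(8, dtype=jnp.float32)
doubled, squared = two_outputs(x)
```

The call returns its results with the same structure as out_shape, so a tuple of two ShapeDtypeStructs yields a tuple of two arrays that can be unpacked directly.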

sharadmv Mar 6, 2024
Collaborator

Answer selected by bsaoptima
Category
Q&A
2 participants