Hi, happy to help with this, but I'm not sure how to interpret your XML-like syntax for array representation. What do the fields mean? How should I interpret (Sorry if this is standard stuff... if you have docs about these representations you can point to, I may be able to read up and better understand what you're asking.)
Hi all!
I'm a developer over at the https://github.com/scikit-hep/awkward-1.0/ project.
From our docs:
I recently started looking at our prototype JAX integration, and I'm looking to add the ability to compute the Jacobian of a function, e.g.
Awkward Arrays are built by composing layout nodes, so a NumPy array with shape `(4, 3, 2)` is built as:

Our initial approach to integration with JAX was to represent Awkward Arrays as single atoms with a number of leaves, i.e. the example array above would be flattened into
This means that both the input and output arrays of, e.g., `func` are flattened in JAX terms into single flat buffers (in this case) alongside some auxiliary data.

However, to make this easier to reason about, I've mocked up a demo in which every layout node (not just the root node) can be converted into a tree. The demo only supports `NumpyArray` and `RegularArray`, even though we have a wider range of node types.

In the example, the array is not ragged, so the input array has shape `(4, 3, 2)`, and the output of `func(array)` should have shape `(8, 3, 2)`. Therefore, the Jacobian should have shape `(8, 3, 2, 4, 3, 2)`. Using our layout system, we'd expect the following structure:

In the comments, a key shows how the size (RHS) and length (LHS) are computed from each `RegularArray` node.

What I'd like from JAX is to ask our tree hooks to:

…`(4, 3, 2)`) with the flat 8 × 3 × 2 × 4 × 3 × 2 = 1152 Jacobian values

I have played around with this a bit already, but I'm clearly not 100% confident about the details of what JAX is doing. Rather than guessing, I was hoping to get some pointers from the JAX team on how to proceed from here.
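For reference, here is a minimal sketch of the shape bookkeeping described above. It assumes a made-up stand-in for `func` (it stacks `x` and `x**2`, so shape `(4, 3, 2)` maps to `(8, 3, 2)`) and hypothetical, stripped-down `NumpyArray`/`RegularArray` classes that are not Awkward's real layout classes. The first part shows that `jax.jacobian` on plain arrays returns the output shape followed by the input shape, i.e. 1152 values. The second part registers the hypothetical nodes with `jax.tree_util.register_pytree_node_class` and shows the nesting JAX uses for pytree-to-pytree Jacobians: the output tree is on the outside, and its leaf is replaced by a copy of the input tree whose leaf holds a flat `(48, 24)` block of the same values.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


# Hypothetical stand-in for `func`: stacks x and x**2 along axis 0,
# so an input of shape (4, 3, 2) gives an output of shape (8, 3, 2).
def func(x):
    return jnp.concatenate([x, x ** 2], axis=0)


x = jnp.arange(24.0).reshape(4, 3, 2)
jac = jax.jacobian(func)(x)
print(jac.shape)  # (8, 3, 2, 4, 3, 2): output shape followed by input shape
print(jac.size)   # 1152 = (8*3*2) * (4*3*2)


# Hypothetical, stripped-down layout nodes (NOT Awkward's real classes).
# Each node is registered as a pytree so JAX can walk the whole layout.
@tree_util.register_pytree_node_class
class NumpyArray:
    def __init__(self, data):
        self.data = data  # flat buffer: the only pytree leaf

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@tree_util.register_pytree_node_class
class RegularArray:
    def __init__(self, content, size):
        self.content = content  # child node (itself a pytree)
        self.size = size        # static metadata, kept in the aux data

    def tree_flatten(self):
        return (self.content,), self.size

    @classmethod
    def tree_unflatten(cls, size, children):
        return cls(children[0], size)


# Layout for shape (4, 3, 2): RegularArray(size=3) of RegularArray(size=2)
# of a flat NumpyArray with 24 entries.
layout = RegularArray(RegularArray(NumpyArray(x.reshape(-1)), 2), 3)


def func_layout(node):
    # Same computation as `func`, expressed on the flat buffer.
    data = node.content.content.data
    out = jnp.concatenate([data, data ** 2])  # 48 entries, viewed as (8, 3, 2)
    return RegularArray(RegularArray(NumpyArray(out), node.content.size), node.size)


jac_tree = jax.jacobian(func_layout)(layout)
# Outer structure: the output layout. Its leaf is replaced by a copy of the
# input layout, whose own leaf holds d(out_flat) / d(in_flat).
block = jac_tree.content.content.data.content.content.data
print(block.shape)  # (48, 24)
```

Because `RegularArray.size` is returned as auxiliary data rather than as a child, JAX treats it as static structure and never differentiates with respect to it; only the flat buffers are leaves. Reshaping `block` to `(8, 3, 2, 4, 3, 2)` recovers the plain-array Jacobian from the first part.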
Please feel free to ask follow-up questions if anything needs further explanation. Thanks!