How to pass namedtuples and arrays (made hashable) to a function as static arguments? #8038
This is a direct follow-up to the excellent answer by @mattjj to the question of how to make arrays hashable so they can be used as static arguments. In my case, we need to pass many arguments to the function; some of them are static and some are traced. Our approach was to group them into namedtuples, which act as a sort of dictionary, and pass those. I'm having trouble figuring out how to pass both a namedtuple and an array (made hashable) to the same function. I tried the following (note the slight change in the arguments compared with @mattjj's answer).
I find that the calls to function
I think this is because it is trying to index a namedtuple and failing. How do I get around this problem?

P.S.: The motivation is that, for my problem, I know the sizes of the matrices that will be generated a priori. So that the dimensions are NOT traced values, we pre-compute them and pass an array containing the dimensions. That is why we are trying to make this array (holding the dimensions of the arrays created inside the jitted function) a static argument.
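A minimal sketch of the idea in the P.S., assuming the dimensions can be stored as plain Python ints: a namedtuple whose fields are all hashable is itself hashable, so it can be given to `jax.jit` through `static_argnums` without any wrapper. The `Dims` namedtuple and the `build` function are hypothetical names chosen for illustration.

```python
from collections import namedtuple
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical container for precomputed matrix dimensions. Because its
# fields are plain Python ints, the namedtuple is hashable and can be static.
Dims = namedtuple("Dims", ["rows", "cols"])


@partial(jax.jit, static_argnums=(1,))
def build(x, dims):
    # dims is not traced: dims.rows and dims.cols are concrete ints here,
    # so they can be used directly as a shape.
    return jnp.zeros((dims.rows, dims.cols)) + x


# Convert a precomputed NumPy array of sizes into hashable Python ints.
sizes = np.array([2, 3])
dims = Dims(*(int(s) for s in sizes))
out = build(1.0, dims)
```

Converting the array to ints up front avoids hashing an array at all; each distinct `Dims` value leads to a separate compilation, which is the usual cost of a static argument.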
The error is because you are creating a `HashableArrayWrapper` of a namedtuple `nt`, which later calls `sum(nt(1))`. That fails because `nt` is, of course, not a valid input to a JAX operation.
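One way to apply this, sketched under assumptions: wrap only the array in a `HashableArrayWrapper` and pass the namedtuple as its own static argument, since a namedtuple with hashable fields is already hashable. The wrapper below hashes the array's shape, dtype and raw bytes with NumPy instead of calling a JAX reduction; that hash/equality scheme and the names `Params` and `f` are assumptions for illustration, not the code from @mattjj's answer.

```python
from collections import namedtuple
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


class HashableArrayWrapper:
    """Assumed wrapper making a NumPy array usable as a jit static argument."""

    def __init__(self, val):
        self.val = np.asarray(val)

    def __hash__(self):
        # Hash shape, dtype and contents; no JAX op is applied to the value.
        return hash((self.val.shape, self.val.dtype.str, self.val.tobytes()))

    def __eq__(self, other):
        return (isinstance(other, HashableArrayWrapper)
                and self.val.shape == other.val.shape
                and self.val.dtype == other.val.dtype
                and bool(np.array_equal(self.val, other.val)))


# Hypothetical bundle of static settings; hashable because its fields are.
Params = namedtuple("Params", ["scale", "n"])


@partial(jax.jit, static_argnums=(1, 2))
def f(x, params, wrapped_dims):
    # Both static arguments arrive untraced, so these are concrete ints.
    dims = tuple(int(d) for d in wrapped_dims.val)
    y = jnp.ones(dims) * params.scale
    return x + y.sum() * params.n


params = Params(scale=2.0, n=3)  # passed as-is, not wrapped
dims = HashableArrayWrapper(np.array([2, 3]))  # only the array is wrapped
out = f(1.0, params, dims)
```

The design choice is simply to keep each static argument in the form that is already hashable: the namedtuple goes in directly, and only the raw array needs the wrapper.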