Using tracers as dictionary keys #10560
Is there any way to use tracers as dictionary keys so that their hash matches their runtime value? For example:

```python
import jax
from typing import Dict

sample_dict = {"agent_0": 0, "agent_1": 1, "agent_2": 2, "agent_3": 3}

def unjitted_fun(example_id: int, dictionary: Dict[str, int]):
    example_key = f"agent_{example_id}"
    example_output = dictionary[example_key]
    return example_output

for i in range(4):
    example_output = unjitted_fun(i, sample_dict)
    print(example_output)
```

This outputs 0, 1, 2 and 3, one per line.
However, the jitted version fails:

```python
jitted_fun = jax.jit(unjitted_fun)

for i in range(4):
    example_output = jitted_fun(i, sample_dict)
    print(example_output)
```

The dictionary lookup raises a `KeyError`.

Any help would be appreciated.
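To see why the lookup fails, one can record the key string that the f-string produces while `jax.jit` traces the function. In this minimal sketch, `build_key` and `traced_keys` are hypothetical names used only for illustration. The recorded string embeds the tracer's representation rather than a number (its exact text varies between JAX versions), so it never matches a key such as `"agent_2"`.

```python
import jax

traced_keys = []  # hypothetical list, filled in at trace time

def build_key(example_id):
    # Under jit, example_id is a tracer, so the f-string formats the tracer
    # object itself rather than a concrete integer.
    traced_keys.append(f"agent_{example_id}")
    # The append above runs only while tracing; later calls with the same
    # argument types reuse the compiled code and skip it.
    return example_id

jax.jit(build_key)(2)
print(traced_keys[0])  # something like "agent_Traced<...>", not "agent_2"
```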
The short answer is no: there is no way to make the hash of a tracer match its runtime value, because a tracer is an abstract representation of all possible runtime values (under `jit` it carries only a shape and dtype, not a concrete value).

I suspect this is an XY problem. If you can give a more detailed example of what you're trying to accomplish, we may be able to recommend a viable approach.
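Without knowing the asker's actual goal, two common patterns may apply to the example in the question. This is a hedged sketch, not the approach recommended in the thread: (1) mark `example_id` as static with `static_argnums`, so `jit` receives a concrete Python int and the f-string builds the real key; (2) convert the dictionary to an array once, outside `jit`, and index it with the traced integer. The names `static_fun`, `dict_to_table` and `table_fun` are hypothetical, and option 2 assumes the keys are exactly `agent_0` through `agent_{n-1}`.

```python
import jax
import jax.numpy as jnp
from typing import Dict

sample_dict = {"agent_0": 0, "agent_1": 1, "agent_2": 2, "agent_3": 3}

def unjitted_fun(example_id: int, dictionary: Dict[str, int]):
    example_key = f"agent_{example_id}"
    return dictionary[example_key]

# Option 1: treat example_id as a static (hashable, compile-time) argument.
# The f-string then sees a real int, but each distinct id triggers a retrace.
static_fun = jax.jit(unjitted_fun, static_argnums=0)

# Option 2 (hypothetical helper): build a lookup table once, outside jit.
# Assumes keys are exactly "agent_0" through "agent_{n-1}".
def dict_to_table(dictionary: Dict[str, int]) -> jax.Array:
    n = len(dictionary)
    return jnp.array([dictionary[f"agent_{i}"] for i in range(n)])

@jax.jit
def table_fun(example_id, table):
    # Indexing with a traced integer is fine; note that JAX does not raise
    # on out-of-bounds indices inside jit.
    return table[example_id]

table = dict_to_table(sample_dict)
for i in range(4):
    print(static_fun(i, sample_dict), table_fun(i, table))
```

Option 1 is the smallest change but recompiles for every new `example_id`; option 2 compiles once and also works when the index itself is traced, for example under `jax.vmap` or `jax.lax.scan`.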