How to access the default is_leaf function #17364
Replies: 2 comments 3 replies
-
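No, there is no default `is_leaf` function. If you're curious about how pytree flattening is implemented, you can find the implementation here: https://github.com/tensorflow/tensorflow/blob/949b63fa788411e4ddba84bbd70385aedd4ed4a2/tensorflow/compiler/xla/python/pytree.cc#L186 (there, `is_leaf` is called `leaf_predicate`). Roughly, the logic is: if a user-supplied predicate is given and returns true for a node, that node is treated as a leaf; otherwise the node is flattened if its type is a registered pytree node, and treated as a leaf if it is not.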
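A minimal Python sketch of that logic may help. It is a toy model, not the actual C++ code: the function name `flatten` is hypothetical, and it assumes only `None`, tuples, lists, and dicts are registered containers (the real registry also covers namedtuples and user-registered types).

```python
# Toy model of pytree flattening; NOT the real JAX implementation.
# Assumption: only None, tuple, list, and dict count as registered containers.

def flatten(node, is_leaf=None):
    """Return the leaves of `node` in traversal order."""
    # 1. A user-supplied predicate is consulted first.
    if is_leaf is not None and is_leaf(node):
        return [node]
    # 2. Otherwise, registered container types are recursed into.
    if node is None:
        return []  # None behaves like an empty container: no leaves
    if isinstance(node, (tuple, list)):
        return [leaf for child in node for leaf in flatten(child, is_leaf)]
    if isinstance(node, dict):
        return [leaf for key in sorted(node) for leaf in flatten(node[key], is_leaf)]
    # 3. Anything else is a leaf.
    return [node]


tree = {"a": [1, None], "b": (2, {"c": None})}
default_leaves = flatten(tree)
with_none_leaves = flatten(tree, is_leaf=lambda x: x is None)
```

In this sketch there is no separate "default predicate" to call: when the predicate returns false, control simply falls through to the built-in handling in steps 2 and 3.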
-
I understand that. But in my function, I want to call the default `is_leaf`. If you want to understand the reason: I just want the functions to perform as normal, except for treating `None` as a leaf:

    def is_leaf(x):
        if x is None:
            return True
        return default_is_leaf(x)

My problem is that I don't know how to call `default_is_leaf`.
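One possible approach, sketched here under the assumption that the goal is only to keep `None` as a leaf: because the predicate is consulted before the built-in handling, and a false result falls through to that handling, the predicate only has to return true for the extra cases. The example tree and the name `none_is_leaf` are illustrative.

```python
import jax

tree = {"a": [1, None], "b": (2, {"c": None})}

# Default behavior: None is an empty pytree node, so it contributes no leaves.
default_leaves = jax.tree_util.tree_leaves(tree)

# The predicate flags only the extra case; everything else falls through
# to the normal flattening rules.
none_is_leaf = lambda x: x is None
leaves = jax.tree_util.tree_leaves(tree, is_leaf=none_is_leaf)

# The same predicate works with tree_map, so None entries reach the function.
mapped = jax.tree_util.tree_map(
    lambda x: 0 if x is None else x + 1, tree, is_leaf=none_is_leaf
)
```

With this pattern there is no need for a fallback call at the end of the predicate.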
-
Many functions in `jax.tree_util` allow a user-specified `is_leaf` function. I'd like to override the default behavior in only a few cases, so it would be convenient to call the default `is_leaf` function. But how can I do that?