Hi, I wonder if anyone can advise on this, please? I run the following code:

    import jax

    A = 2
    B = None
    C = None

    def branching_fn1(A, B = None, C = None):
        def positive_branch1(A):
            return A
        def negative_branch1(A):
            def branching_fn2(A, B, C):
                return jax.lax.cond((C is None), positive_branch2, negative_branch2, A)
            def positive_branch2(A):
                A = A + B # rewrite this in jax primitives
                return A
            def negative_branch2(A):
                A = A + (C*B) # rewrite this in jax primitives
                return A
            return branching_fn2(A, B, C)
        return jax.lax.cond((B is None), positive_branch1, negative_branch1, A)

    output = branching_fn1(A, B, C)
    print(output)

And get variations on the error: unsupported operand type(s) for +: 'DynamicJaxprTracer' and 'NoneType', regardless of whether the conditional path should be taken given the starting data.

I note that if I run the code below, it works as I would expect. Since I'm trying to convert NumPy code to JAX, and it already uses "None", understanding how to resolve this would be really helpful. B and C in the original code may be None or a vector/scalar. In the working code I could possibly remove the first conditional and have two versions of jit'd functions dispatched with a Pythonic "if", but that doesn't help with the second one.

Any help/advice/links would be most appreciated. I am still very new to JAX, so please don't be afraid to explain it to me like I'm a small child. Thanks in advance!

    import jax

    A = 2
    B = -1
    C = -1

    def branching_fn1(A, B = -1, C = -1):
        def positive_branch1(A):
            return A
        def negative_branch1(A):
            def branching_fn2(A, B, C):
                return jax.lax.cond((C == -1), positive_branch2, negative_branch2, A)
            def positive_branch2(A):
                A = A + B # rewrite this in jax primitives
                return A
            def negative_branch2(A):
                A = A + (C*B) # rewrite this in jax primitives
                return A
            return branching_fn2(A, B, C)
        return jax.lax.cond((B == -1), positive_branch1, negative_branch1, A)

    output = branching_fn1(A, B, C)
    print(output)
Replies: 1 comment 3 replies
The problem is that `lax.cond` is not a short-circuiting operation: in order to push the computation down to XLA, it traces both the `true_fun` and `false_fun` to construct an intermediate representation, even if the final code will not need to evaluate one of the branches.

But even so, I have trouble following the logic of the first code-block: for example, it seems that if `C is not None` and `A = None`, then even with short-circuiting logic `positive_branch2` would attempt to execute `A = A + B`, which would be ill-defined even without respect to JAX.

Another thing you could consider: if you're branching on Python identity comparisons (as opposed to value comparisons), instead of using `lax.cond` you …
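
To make the tracing behaviour concrete, here is a minimal sketch (not taken from the thread). The names `traced`, `true_branch`, `false_branch`, `pick` and `add_optional` are hypothetical, chosen only for illustration. The first half records which branch functions JAX traces when `lax.cond` runs under `jit`. The second half shows one possible way to handle `None` defaults with ordinary Python `if` statements, which are resolved while tracing; this is an assumption about what might suit the original code, not necessarily what the reply goes on to recommend.

```python
import jax

traced = []  # hypothetical log of which branch functions get traced


def true_branch(x):
    traced.append("true")
    return x


def false_branch(x):
    traced.append("false")
    return x + 1.0


@jax.jit
def pick(pred, x):
    # Under jit, pred is a tracer, so lax.cond has to trace both
    # branches, whichever one ends up being executed.
    return jax.lax.cond(pred, true_branch, false_branch, x)


result = pick(True, 1.0)
print(sorted(set(traced)))  # both "true" and "false" appear


def add_optional(A, B=None, C=None):
    # `B is None` and `C is None` are plain Python identity checks.
    # They are evaluated at trace time, so only the selected
    # expression is ever traced and None never meets a tracer.
    if B is None:
        return A
    if C is None:
        return A + B
    return A + C * B


add_jit = jax.jit(add_optional)
print(add_jit(2.0))            # B is None: returns A
print(add_jit(2.0, 3.0))       # C is None: returns A + B
print(add_jit(2.0, 3.0, 4.0))  # returns A + C * B
```

With this approach `jit` compiles a separate version for each combination of `None` and non-`None` arguments, which is similar in spirit to the "two jit'd functions" idea from the question but without writing the dispatch by hand.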