Difficulties converting with jax2tf and polymorphism #15974
Hello, I'm new to JAX and struggling to convert a JAX model (MAXIM) to TensorFlow.js. Converting the model with a variable batch dimension and fixed image dimensions works well:

tfjs.converters.convert_jax(
    jax_model.apply,
    params,
    model_dir='./tmp',
    input_signatures=[tf.TensorSpec([None, 256, 256, 3], tf.float32)],
    polymorphic_shapes=["(b, 256, 256, 3)"],
)

However, for certain models, fixed input sizes result in a noticeable patching effect, so I'm trying to update the model to support dynamic image inputs. Changing the conversion code to something like:

tfjs.converters.convert_jax(
    jax_model.apply,
    params,
    model_dir='./tmp',
    input_signatures=[tf.TensorSpec([None, None, 256, 3], tf.float32)],
    polymorphic_shapes=["(b, h, 256, 3)"],
)

fails. I've refactored the library's JAX code a bit to avoid the use of einops, so I'm not entirely sure how to share a runnable code snippet, but I can share the errors I'm seeing. The code fails at:

reshaped_padded = np.reshape(x_pad, new_shape)

With the values of
This code fails an assertion here. When I dig in,

Can anyone point me in the right direction? I'm happy to provide more information if it's helpful.

UPDATE 1

OK, I was able to whittle it down to the offending code, but I don't know why this fails or where in the docs to look for more information. I would appreciate any pointers! This fails:

from jax.experimental.jax2tf.shape_poly import _DimExpr
h = _DimExpr.from_var('h')
r = h // (h // 16)
r.bounds()

Because this code returns a (h // 16).bounds()
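
One plausible reading of the failure above can be sketched in plain Python, without touching JAX internals. The helper `floordiv_bounds` is hypothetical and is not the library's implementation. It assumes a dimension variable such as `h` is only known to satisfy `h >= 1`, in which case `h // 16` has a lower bound of 0 and cannot safely be used as a divisor:

```python
import math


def floordiv_bounds(lo, hi, divisor):
    """Bounds of x // divisor for an integer x in [lo, hi].

    Hypothetical helper for illustration only; `divisor` must be a
    positive integer and `hi` may be math.inf.
    """
    if divisor <= 0:
        raise ValueError("divisor must be positive")
    lower = lo // divisor
    upper = math.inf if hi == math.inf else hi // divisor
    return lower, upper


# Assumption: all we know about the dimension variable is h >= 1.
h_bounds = (1, math.inf)

# Bounds of the inner expression h // 16.
inner = floordiv_bounds(*h_bounds, 16)
print("bounds of h // 16:", inner)

# If 0 lies within the bounds, h // (h // 16) may divide by zero.
divisor_may_be_zero = inner[0] <= 0 <= inner[1]
print("divisor may be zero:", divisor_may_be_zero)

# With a stronger assumption (h >= 16) the lower bound becomes 1.
print("bounds with h >= 16:", floordiv_bounds(16, math.inf, 16))
```

Under this reading, the outer division only becomes well defined once something guarantees that `h` is at least 16.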
UPDATE 2

So I continue digging!

from jax.experimental.jax2tf.shape_poly import _DimAtom, _DimExpr
h = _DimExpr.from_var('h')
r = h // 16
for mon, coeff in r.monomials():
    for a, exp in mon.items():
        for opnd in a.operands:
            print('Bounds for DimExpr', opnd.bounds())
        print('Bounds for DimAtom', a.bounds())

Produces:

And what appears to be happening is that, on this line, the result of

I further presume that the value of

So I think that leads me to ask: is there a way, when specifying dynamic variables, to specify a minimum value for a variable? E.g., is there a way to specify in the polymorphism settings that a value should be in the range of

UPDATE 3

I learned from this that you can specify a dimension variable to be a multiple:
Which is great and exactly what I was looking for.

That leads me to a new error. The model that I'm converting has an Upsample layer that attempts to cast the height and width to concrete values. I'm trying to make this work with dynamic values. I've tried the following:

print(int(h * self.ratio))

yields:

Fair enough. So I then try:

print((h * self.ratio).astype(int))

Which yields:

However, when I try to pass that to

x = jax.image.resize(
    x,
    shape=(n, h, w, c),
    method="bilinear",
)

I get:

The readme indicates that this error:

Does this mean that
Replies: 1 comment
The thrust of the original question was, is there a way to specify that a dynamic shape has a range? The answer is, sort of, you can specify that a dynamic variable can be a multiple:
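
As a rough illustration of why a multiple-of constraint helps here, the arithmetic can be checked in plain Python. This sketch assumes the constraint amounts to writing the height as `16 * k` for some `k >= 1` (for example, a spec along the lines of `16*h`, which is an assumption and not taken from this thread). `patch_count` is a hypothetical stand-in for the `h // 16` expression from the question:

```python
def patch_count(h, patch=16):
    """Hypothetical stand-in for the `h // 16` expression in the question."""
    return h // patch


# Unconstrained heights (h >= 1): some values give a zero patch count,
# so h // (h // 16) would divide by zero.
unconstrained_zero = [h for h in range(1, 65) if patch_count(h) == 0]
print("heights with zero patches:", unconstrained_zero)

# Heights constrained to multiples of 16 (h = 16 * k, k >= 1):
# the patch count equals k and is always at least 1.
constrained = [patch_count(16 * k) for k in range(1, 5)]
print("patch counts for 16*k:", constrained)

# The outer division from the question is then exact and constant.
ratios = [(16 * k) // patch_count(16 * k) for k in range(1, 5)]
print("h // (h // 16) for 16*k:", ratios)
```

Expressing the height as a multiple is what removes the problematic case: the divisor can no longer be zero, and the division has no remainder.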
This discussion is getting a bit long, so I've carved out the latest question into its own discussion.