Building Set Transformer with dropout=0 fails with JAX backend #206

@maltelueken

When I try to build a SetTransformer with dropout=0 inside a ContinuousApproximator using the JAX backend, I get the following error:

RuntimeError: Unable to automatically build the model. Please build it yourself before calling fit/evaluate/predict. A model is 'built' when its variables have been created and its `self.built` attribute is True. Usually, calling the model on a batch of data is the right way to build it.
Exception encountered:
'Exception encountered when calling ContinuousApproximator.call().

Model ContinuousApproximator does not have a `call()` method implemented.

Arguments received by ContinuousApproximator.call():
  • args=({'inference_conditions': 'jnp.ndarray(shape=(64, 1), dtype=float32)', 'inference_variables': 'jnp.ndarray(shape=(64, 5), dtype=float32)', 'summary_variables': 'jnp.ndarray(shape=(64, 250, 2), dtype=float32)'},)
  • kwargs=<class 'inspect._empty'>'
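
For reference, the batch structure reported in the error above can be sketched as follows. This is only an illustration of the keys, shapes, and dtype taken from the message: it uses NumPy with random placeholder values instead of JAX arrays, the helper name `make_dummy_batch` is hypothetical, and it does not reproduce the failure on its own.

```python
import numpy as np


def make_dummy_batch(batch_size=64, set_size=250, seed=0):
    """Build a placeholder batch with the shapes listed in the error message.

    Hypothetical helper: the values are random and carry no meaning; only the
    keys, shapes, and dtype mirror what ContinuousApproximator.call() received.
    """
    rng = np.random.default_rng(seed)
    return {
        "inference_conditions": rng.normal(size=(batch_size, 1)).astype("float32"),
        "inference_variables": rng.normal(size=(batch_size, 5)).astype("float32"),
        "summary_variables": rng.normal(size=(batch_size, set_size, 2)).astype("float32"),
    }


batch = make_dummy_batch()
for name, value in batch.items():
    # Print each entry's shape and dtype for comparison with the error message.
    print(name, value.shape, value.dtype)
```

The `summary_variables` entry is the set-valued input (64 sets of 250 two-dimensional observations), which is presumably what the SetTransformer consumes as the summary network.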

Happy to provide more details if needed.
