diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index eb2f4fb267..300f2f7323 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -117,10 +117,8 @@ def create_thunk_inputs(self, storage_map): for n in self.fgraph.inputs: sinput = storage_map[n] if isinstance(sinput[0], Generator): - new_value = jax_typify( - sinput[0], dtype=getattr(sinput[0], "dtype", None) - ) - sinput[0] = new_value + # Neet to convert Generator into JAX PRNGkey + sinput[0] = jax_typify(sinput[0]) thunk_inputs.append(sinput) return thunk_inputs