From 49e4780fc36a522ddfa9e02816ef181bc787ab9c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 12 Jun 2025 11:35:03 +0200 Subject: [PATCH] Simplify typify of Generators in JAXLinker --- pytensor/link/jax/linker.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index eb2f4fb267..300f2f7323 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -117,10 +117,8 @@ def create_thunk_inputs(self, storage_map): for n in self.fgraph.inputs: sinput = storage_map[n] if isinstance(sinput[0], Generator): - new_value = jax_typify( - sinput[0], dtype=getattr(sinput[0], "dtype", None) - ) - sinput[0] = new_value + # Neet to convert Generator into JAX PRNGkey + sinput[0] = jax_typify(sinput[0]) thunk_inputs.append(sinput) return thunk_inputs