44"""
55
66from dataclasses import dataclass
7- from functools import partial
87from typing import Callable , Tuple
98
109import jax
@@ -191,7 +190,6 @@ def init(  # type: ignore
             steps=jnp.array(0),
         )
 
-    @partial(jax.jit, static_argnames=("self",))
     def _compute_diversity_reward(
         self, transition: QDTransition, training_state: DadsTrainingState
     ) -> Reward:
@@ -244,8 +242,7 @@ def _compute_diversity_reward(
 
         return reward
 
-    @partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
-    def play_step_fn(
+    def play_step_fn(  # type: ignore
         self,
         env_state: EnvState,
         training_state: DadsTrainingState,
@@ -339,14 +336,13 @@ def play_step_fn(
 
         return next_env_state, training_state, transition
 
-    @partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size"))
-    def eval_policy_fn(
+    def eval_policy_fn(  # type: ignore
         self,
         training_state: DadsTrainingState,
         eval_env_first_state: EnvState,
         play_step_fn: Callable[
-            [EnvState, Params, RNGKey],
-            Tuple[EnvState, Params, RNGKey, QDTransition],
+            [EnvState, Params],
+            Tuple[EnvState, Params, QDTransition],
         ],
         env_batch_size: int,
     ) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
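The hunk above also narrows the expected `play_step_fn` callable: it no longer threads an explicit `RNGKey` through its inputs and outputs, consistent with the new `play_step_fn` method earlier in this diff, which takes and returns the training state instead. A minimal sketch of the annotation change, using hypothetical placeholder types in place of QDax's `EnvState`, `Params`, and `QDTransition`:

from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-ins; the real types come from qdax.
EnvState = Dict[str, Any]
Params = Dict[str, Any]
QDTransition = Tuple[Any, ...]

# Before: the step callable received and returned an explicit RNG key.
# PlayStepFn = Callable[
#     [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition]
# ]

# After: no separate key; any randomness is assumed to travel inside the
# state object that the callable already consumes and returns.
PlayStepFn = Callable[[EnvState, Params], Tuple[EnvState, Params, QDTransition]]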
@@ -400,7 +396,6 @@ def eval_policy_fn(
 
         return true_return, true_returns, diversity_returns, transitions.state_desc
 
-    @partial(jax.jit, static_argnames=("self",))
     def _compute_reward(
         self, transition: QDTransition, training_state: DadsTrainingState
     ) -> Reward:
@@ -417,7 +412,6 @@ def _compute_reward(
             transition=transition, training_state=training_state
         )
 
-    @partial(jax.jit, static_argnames=("self",))
     def _update_dynamics(
         self, operand: Tuple[DadsTrainingState, QDTransition]
     ) -> Tuple[Params, float, optax.OptState]:
@@ -448,7 +442,6 @@ def _update_dynamics(
             dynamics_optimizer_state,
         )
 
-    @partial(jax.jit, static_argnames=("self",))
     def _not_update_dynamics(
         self, operand: Tuple[DadsTrainingState, QDTransition]
     ) -> Tuple[Params, float, optax.OptState]:
@@ -464,7 +457,6 @@ def _not_update_dynamics(
             training_state.dynamics_optimizer_state,
         )
 
-    @partial(jax.jit, static_argnames=("self",))
     def _update_networks(
         self,
         training_state: DadsTrainingState,
@@ -566,7 +558,6 @@ def _update_networks(
 
         return new_training_state, metrics
 
-    @partial(jax.jit, static_argnames=("self",))
     def update(
         self,
         training_state: DadsTrainingState,
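Every hunk in this commit removes a `@partial(jax.jit, ...)` decorator with `static_argnames`, leaving the methods undecorated (and `functools.partial` unimported). A plausible reading, not stated in the diff itself, is that compilation now happens at the call site rather than per method. A minimal sketch of that pattern with hypothetical names, showing that a bound method can be jitted where it is used without declaring `self` static:

import jax
import jax.numpy as jnp


class Toy:
    """Hypothetical stand-in for the DADS class; `factor` is plain Python config."""

    def __init__(self, factor: float):
        self.factor = factor

    # Previously the method itself would carry
    # @partial(jax.jit, static_argnames=("self",)); now it is left plain.
    def scale(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.factor * x


toy = Toy(2.0)
# jit the bound method at the point of use: `self` is captured in the
# closure, so no static argument declaration is needed.
scaled_fn = jax.jit(toy.scale)
print(scaled_fn(jnp.arange(3.0)))  # [0. 2. 4.]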