@@ -300,6 +300,7 @@ def loss_fun(inputs, targets):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     return self._f_loss_compiled[shared_args_str]
@@ -311,6 +312,7 @@ def f_grad(self, shared_args=None) -> Callable:
     _f_loss_internal = self.f_loss(shared_args, jit=False)
     dyn_vars = self.target.vars()
     dyn_vars.update(self.dyn_vars)
+    dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
    tran_vars = dyn_vars.subset(bm.TrainVar)
    grad_f = bm.grad(_f_loss_internal,
                     dyn_vars=dyn_vars.unique(),
@@ -339,6 +341,7 @@ def train_func(inputs, targets):
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
       dyn_vars.update(self.optimizer.vars())
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_train_compiled[shared_args_str] = bm.jit(train_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_train_compiled[shared_args_str] = train_func
@@ -453,6 +456,7 @@ def loss_fun(inputs, targets):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     else:
@@ -480,6 +484,7 @@ def run_func(xs):
     if self.jit[c.PREDICT_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_predict_compiled[shared_args_str] = run_func
@@ -505,6 +510,7 @@ def loss_fun(t, i, input_, target_):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     else:
@@ -529,6 +535,7 @@ def train_step(*x):
     if self.jit[c.FIT_PHASE]:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       run_func = lambda all_inputs: bm.for_loop(train_step, dyn_vars.unique(), all_inputs)

     else:
@@ -582,6 +589,7 @@ def run_func(t, i, x):
     if self.jit[c.FIT_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_predict_compiled[shared_args_str] = run_func
0 commit comments