6363from pytensor .compile .mode import Mode , get_mode
6464from pytensor .compile .profiling import register_profiler_printer
6565from pytensor .configdefaults import config
66- from pytensor .gradient import DisconnectedType , NullType , Rop , grad , grad_undefined
66+ from pytensor .gradient import (
67+ DisconnectedType ,
68+ NullType ,
69+ Rop ,
70+ disconnected_type ,
71+ grad ,
72+ grad_undefined ,
73+ )
6774from pytensor .graph .basic import (
6875 Apply ,
6976 Variable ,
@@ -3073,7 +3080,7 @@ def compute_all_gradients(known_grads):
30733080 )
30743081 outputs = local_op (* outer_inputs , return_list = True )
30753082 # Re-order the gradients correctly
3076- gradients = [DisconnectedType ()()]
3083+ gradients = [disconnected_type ()] # n_steps is disconnected
30773084
30783085 offset = info .n_mit_mot + info .n_mit_sot + info .n_sit_sot + n_sitsot_outs
30793086 for p , (x , t ) in enumerate (
@@ -3098,7 +3105,7 @@ def compute_all_gradients(known_grads):
30983105 else :
30993106 gradients .append (x [::- 1 ])
31003107 elif t == "disconnected" :
3101- gradients .append (DisconnectedType () ())
3108+ gradients .append (disconnected_type ())
31023109 elif t == "through_untraced" :
31033110 gradients .append (
31043111 grad_undefined (
@@ -3126,7 +3133,7 @@ def compute_all_gradients(known_grads):
31263133 else :
31273134 gradients .append (x [::- 1 ])
31283135 elif t == "disconnected" :
3129- gradients .append (DisconnectedType () ())
3136+ gradients .append (disconnected_type ())
31303137 elif t == "through_untraced" :
31313138 gradients .append (
31323139 grad_undefined (
@@ -3149,15 +3156,15 @@ def compute_all_gradients(known_grads):
31493156 if not isinstance (dC_dout .type , DisconnectedType ) and connected :
31503157 disconnected = False
31513158 if disconnected :
3152- gradients .append (DisconnectedType () ())
3159+ gradients .append (disconnected_type ())
31533160 else :
31543161 gradients .append (
31553162 grad_undefined (
31563163 self , idx , inputs [idx ], "Shared Variable with update"
31573164 )
31583165 )
31593166
3160- gradients += [ DisconnectedType () () for _ in range (info .n_nit_sot )]
3167+ gradients . extend ( disconnected_type () for _ in range (info .n_nit_sot ))
31613168 begin = end
31623169
31633170 end = begin + n_sitsot_outs
@@ -3167,7 +3174,7 @@ def compute_all_gradients(known_grads):
31673174 if t == "connected" :
31683175 gradients .append (x [- 1 ])
31693176 elif t == "disconnected" :
3170- gradients .append (DisconnectedType () ())
3177+ gradients .append (disconnected_type ())
31713178 elif t == "through_untraced" :
31723179 gradients .append (
31733180 grad_undefined (
@@ -3195,7 +3202,7 @@ def compute_all_gradients(known_grads):
31953202 ):
31963203 disconnected = False
31973204 if disconnected :
3198- gradients [idx ] = DisconnectedType () ()
3205+ gradients [idx ] = disconnected_type ()
31993206 return gradients
32003207
32013208 def R_op (self , inputs , eval_points ):
0 commit comments