@@ -2509,36 +2509,40 @@ def compute_all_gradients(known_grads):
             return rval
 
         var_mappings = self.get_oinp_iinp_iout_oout_mappings()
-        dC_dinps_t = [None for inp in diff_inputs]
         disconnected_dC_dinps_t = [True for inp in diff_inputs]
+
+        n_mit_mot_outs = info.n_mit_mot_outs
+        # In the case of mit-mot there can be more inner outputs than outer ones
+        n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot
+        idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
+        idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot
+
+        # Create dummy variables for the internal input gradients
+        states = (
+            self.inner_mitmot(self_inputs)
+            + self.inner_mitsot(self_inputs)
+            + self.inner_sitsot(self_inputs)
+        )
         dC_dXts = []
         Xts = []
         for idx, Xt in enumerate(diff_outputs):
             # We are looking for x[t-1] for a given x[t]
-            if idx >= info.n_mit_mot_outs:
+            if idx >= n_mit_mot_outs:
                 Xt_placeholder = safe_new(Xt)
                 Xts.append(Xt_placeholder)
 
             # Different processing based on whether Xt is a nitsot output
             # or not. NOTE: This cannot be done by using
             # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
             # the exact same variable can be used as multiple outputs.
-            idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-            idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
-            if idx < idx_nitsot_start or idx >= idx_nitsot_end:
+            if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
                 # What we do here is loop through dC_douts and collect all
                 # those that are connected to the specific one and do an
                 # upcast on all of their dtypes to get the dtype for this
                 # specific output. Deciding if the gradient with this
                 # specific previous step is defined or not is done somewhere
                 # else.
                 dtypes = []
-                states = (
-                    self.inner_mitmot(self_inputs)
-                    + self.inner_mitsot(self_inputs)
-                    + self.inner_sitsot(self_inputs)
-                )
-
                 for pos, inp in enumerate(states):
                     if inp in graph_inputs([Xt]):
                         # Get the index of the outer output that to which
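The index arithmetic introduced in this hunk is the core of the fix: inner outputs are ordered mit-mot, mit-sot, sit-sot, nit-sot, and each mit-mot state contributes one inner output per output tap, so the nit-sot range has to start after `n_mit_mot_outs` inner outputs rather than after `n_mit_mot` outer ones. A minimal sketch of the bookkeeping, with hypothetical tap counts standing in for a real `ScanInfo`:

    # Hypothetical tap counts, for illustration only (not a real ScanInfo).
    n_mit_mot = 1            # one mit-mot state...
    mit_mot_out_taps = [2]   # ...whose inner function emits two output taps
    n_mit_mot_outs = sum(mit_mot_out_taps)  # 2 inner outputs for 1 outer output
    n_mit_sot, n_sit_sot, n_nit_sot = 1, 1, 1

    n_extra_mit_mot_outs = n_mit_mot_outs - n_mit_mot  # 1 "extra" inner output

    # Range of *inner* output indices occupied by nit-sot outputs:
    idx_nitsot_out_start = n_mit_mot_outs + n_mit_sot + n_sit_sot  # 4
    idx_nitsot_out_end = idx_nitsot_out_start + n_nit_sot          # 5

    # The old code started the range at n_mit_mot + n_mit_sot + n_sit_sot == 3,
    # which is off whenever a mit-mot has more than one output tap. The same
    # offset explains the dC_douts[idx - n_extra_mit_mot_outs] indexing below:
    # dC_douts has one entry per *outer* output.
    assert idx_nitsot_out_start - (n_mit_mot + n_mit_sot + n_sit_sot) == n_extra_mit_mot_outs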
@@ -2555,35 +2559,39 @@ def compute_all_gradients(known_grads):
                     new_dtype = config.floatX
                 dC_dXt = safe_new(Xt, dtype=new_dtype)
             else:
-                if isinstance(dC_douts[idx].type, DisconnectedType):
+                # nit-sot outputs
+                # If not disconnected, assume the output gradient type is a valid type for the input gradient
+                if isinstance(
+                    dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType
+                ):
                     continue
-                dC_dXt = safe_new(dC_douts[idx][0])
+                dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0])
             dC_dXts.append(dC_dXt)
 
+        # Handle cases where the very same variable may be used as different outputs
+        # TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
         known_grads = {}
         dc_dxts_idx = 0
         for i in range(len(diff_outputs)):
-            if i < idx_nitsot_start or i >= idx_nitsot_end:
-                if diff_outputs[i] in known_grads:
-                    known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                else:
-                    known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                dc_dxts_idx += 1
+            if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance(
+                dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType
+            ):
+                # Special case where we don't have a dC_dXt for disconnected nitsot outputs
+                continue
+
+            # Just some trouble to avoid a +0
+            if diff_outputs[i] in known_grads:
+                known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
             else:
-                if isinstance(dC_douts[i].type, DisconnectedType):
-                    continue
-                else:
-                    if diff_outputs[i] in known_grads:
-                        known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                    else:
-                        known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                    dc_dxts_idx += 1
+                known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
+            dc_dxts_idx += 1
+
         dC_dinps_t = compute_all_gradients(known_grads)
 
         # mask inputs that get no gradients
         for dx in range(len(dC_dinps_t)):
-            if not dC_dinps_t[dx]:
-                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
+            if dC_dinps_t[dx] is None:
+                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
             else:
                 disconnected_dC_dinps_t[dx] = False
         for Xt, Xt_placeholder in zip(
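The flattened accumulation loop above replaces the old nested branches: disconnected nit-sot outputs are skipped up front, since no dummy `dC_dXt` was created for them, and every remaining gradient is summed into `known_grads` keyed by the inner output variable, because the very same variable may serve as several outputs. A standalone sketch of that pattern, with plain floats standing in for symbolic gradients:

    # Toy stand-ins: the same object reused as two outputs, as can happen in Scan.
    x, y = object(), object()
    diff_outputs = [x, y, x]      # x appears twice
    dC_dXts = [1.0, 2.0, 3.0]     # one dummy gradient per connected output

    known_grads = {}
    dc_dxts_idx = 0
    for out in diff_outputs:
        # Accumulate rather than overwrite when an output repeats,
        # without initializing via a spurious `0 +`.
        if out in known_grads:
            known_grads[out] += dC_dXts[dc_dxts_idx]
        else:
            known_grads[out] = dC_dXts[dc_dxts_idx]
        dc_dxts_idx += 1

    assert known_grads[x] == 4.0 and known_grads[y] == 2.0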
@@ -2846,7 +2854,6 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_sit_sot):
             mitmot_inp_taps.append([0, 1])
             mitmot_out_taps.append([1])
-            through_shared = False
             if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
                 outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
             else:
@@ -3007,9 +3014,7 @@ def compute_all_gradients(known_grads):
             name=f"grad_of_{self.name}" if self.name else None,
             allow_gc=self.allow_gc,
         )
-        outputs = local_op(*outer_inputs)
-        if not isinstance(outputs, list | tuple):
-            outputs = [outputs]
+        outputs = local_op(*outer_inputs, return_list=True)
         # Re-order the gradients correctly
         gradients = [DisconnectedType()()]
 
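The `return_list=True` keyword accepted by `Op.__call__` makes an op call return a list even when there is a single output, which is what makes the removed isinstance-and-wrap branch redundant. A small check of that behavior, assuming PyTensor and using the Elemwise op behind `pt.exp` as an arbitrary single-output example:

    import pytensor.tensor as pt

    x = pt.dscalar("x")
    exp_op = pt.exp(x).owner.op  # the underlying single-output Elemwise Op

    # Calling the Op directly with return_list=True always yields a list.
    out_list = exp_op(x, return_list=True)
    assert isinstance(out_list, list) and len(out_list) == 1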
@@ -3095,7 +3100,6 @@ def compute_all_gradients(known_grads):
             )
         )
 
-        start = len(gradients)
         gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
         begin = end
 