@@ -133,7 +133,6 @@ def logp(
133133
134134 """
135135 if hgf .model_type == "continuous" :
136-
137136 # update this network's attributes
138137 hgf .attributes [0 ]["precision" ] = input_precision
139138
@@ -161,7 +160,6 @@ def logp(
161160 hgf .attributes [3 ]["volatility_coupling_children" ] = (volatility_coupling_2 ,)
162161
163162 elif hgf .model_type == "binary" :
164-
165163 # update this network's attributes
166164 hgf .attributes [0 ]["mean" ] = mean_1
167165 hgf .attributes [1 ]["mean" ] = mean_2
@@ -486,23 +484,26 @@ def perform(
486484 ):
487485 """Perform node operations."""
488486 (
489- grad_mean_1 ,
490- grad_mean_2 ,
491- grad_mean_3 ,
492- grad_precision_1 ,
493- grad_precision_2 ,
494- grad_precision_3 ,
495- grad_tonic_volatility_1 ,
496- grad_tonic_volatility_2 ,
497- grad_tonic_volatility_3 ,
498- grad_tonic_drift_1 ,
499- grad_tonic_drift_2 ,
500- grad_tonic_drift_3 ,
501- grad_volatility_coupling_1 ,
502- grad_volatility_coupling_2 ,
503- grad_input_precision ,
504- grad_response_function_parameters ,
505- ), _ = self .grad_logp (* inputs )
487+ (
488+ grad_mean_1 ,
489+ grad_mean_2 ,
490+ grad_mean_3 ,
491+ grad_precision_1 ,
492+ grad_precision_2 ,
493+ grad_precision_3 ,
494+ grad_tonic_volatility_1 ,
495+ grad_tonic_volatility_2 ,
496+ grad_tonic_volatility_3 ,
497+ grad_tonic_drift_1 ,
498+ grad_tonic_drift_2 ,
499+ grad_tonic_drift_3 ,
500+ grad_volatility_coupling_1 ,
501+ grad_volatility_coupling_2 ,
502+ grad_input_precision ,
503+ grad_response_function_parameters ,
504+ ),
505+ _ ,
506+ ) = self .grad_logp (* inputs )
506507
507508 outputs [0 ][0 ] = np .asarray (grad_mean_1 , dtype = node .outputs [0 ].dtype )
508509 outputs [1 ][0 ] = np .asarray (grad_mean_2 , dtype = node .outputs [1 ].dtype )
0 commit comments