@@ -580,11 +580,15 @@ def detach_decomposition(x):
580
580
return x
581
581
582
582
583
- # @register_decomposition(aten.cudnn_batch_norm)
584
- # def cudnn_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool, exponential_average_factor: float, epsilon: float):
585
- # a, b, c = aten.native_batch_norm(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon)
586
- # return (a,b, c, aten.new_empty(input, (1,)))
587
-
588
- # @register_decomposition(aten.cudnn_batch_norm_backward)
589
- # def cudnn_batch_norm_backward(input: Tensor, grad_output: Tensor, weight: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_var: Optional[Tensor], epsilon: float, reserveSpace: Tensor):
590
- # return aten.native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var, True, epsilon, [True, True, True])
583
@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    """Decompose ``aten.cudnn_batch_norm`` in terms of ``aten.native_batch_norm``.

    The cudnn variant returns a fourth "reserve space" tensor that the native
    op does not produce, so a freshly allocated placeholder stands in for it.
    """
    out, batch_mean, batch_var = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    # cudnn only reports the saved mean/variance when training; in eval mode
    # those two outputs are placeholder tensors instead.
    if not training:
        return (
            out,
            input.new_zeros((1,)),
            input.new_zeros((1,)),
            input.new_zeros((1,)),
        )
    return (out, batch_mean, batch_var, input.new_zeros((1,)))
590
+
591
+
592
@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    """Decompose the cudnn batch-norm backward pass via the native backward op.

    ``reserveSpace`` is cudnn-specific bookkeeping and is ignored here.
    Gradients are requested for all three inputs (input, weight, bias), and
    ``train`` is hard-wired to ``True`` — NOTE(review): presumably this
    decomposition is only reached for training-mode graphs; confirm against
    callers.
    """
    grad_input_mask = [True, True, True]
    result = aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        True,
        epsilon,
        grad_input_mask,
    )
    return result
0 commit comments