
Commit 015f8b6

Add cudnn_batch_norm decomposition to default nvfuser decompositions (#661)
* Add cudnn_batch_norm decomposition to default nvfuser decompositions
* Comments
* Revert zeros change
* Using new_zeros
1 parent f076eaf commit 015f8b6

File tree

2 files changed: +14 −8 lines

functorch/_src/compilers.py

2 additions & 0 deletions

@@ -265,6 +265,8 @@ def nnc_jit(f, static_argnums=None):
         aten.hardswish_backward,
         aten.tanh_backward,
         aten.silu_backward,
+        aten.cudnn_batch_norm,
+        aten.cudnn_batch_norm_backward,
     ]
 )
 default_decompositions = get_decompositions(default_decompositions)
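For context (not part of the diff): `default_decompositions` feeds functorch's nvFuser-backed AOT path, so after this change a CUDA model containing BatchNorm no longer hits an un-decomposed `aten.cudnn_batch_norm`. A minimal sketch of how that table gets exercised, assuming a CUDA build with nvFuser available; the model and shapes here are illustrative assumptions, not taken from the commit:

```python
import torch
from functorch.compile import memory_efficient_fusion

# Illustrative model: Conv + BatchNorm. On CUDA, BatchNorm dispatches to
# aten.cudnn_batch_norm, which now decomposes to aten.native_batch_norm
# via the default decomposition table above.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
).cuda()

fused = memory_efficient_fusion(model)

x = torch.randn(2, 3, 32, 32, device="cuda", requires_grad=True)
out = fused(x)
# The backward pass exercises the cudnn_batch_norm_backward decomposition.
out.sum().backward()
```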

functorch/_src/decompositions.py

12 additions & 8 deletions

@@ -580,11 +580,15 @@ def detach_decomposition(x):
     return x
 
 
-# @register_decomposition(aten.cudnn_batch_norm)
-# def cudnn_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool, exponential_average_factor: float, epsilon: float):
-#     a, b, c = aten.native_batch_norm(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon)
-#     return (a,b, c, aten.new_empty(input, (1,)))
-
-# @register_decomposition(aten.cudnn_batch_norm_backward)
-# def cudnn_batch_norm_backward(input: Tensor, grad_output: Tensor, weight: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_var: Optional[Tensor], epsilon: float, reserveSpace: Tensor):
-#     return aten.native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var, True, epsilon, [True, True, True])
+@register_decomposition(aten.cudnn_batch_norm)
+def cudnn_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool, exponential_average_factor: float, epsilon: float):
+    a, b, c = aten.native_batch_norm(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon)
+    # cuDNN returns running mean and variance only when training is True
+    if training:
+        return (a, b, c, input.new_zeros((1,)))
+    return (a, input.new_zeros((1,)), input.new_zeros((1,)), input.new_zeros((1,)))
+
+
+@register_decomposition(aten.cudnn_batch_norm_backward)
+def cudnn_batch_norm_backward(input: Tensor, grad_output: Tensor, weight: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_var: Optional[Tensor], epsilon: float, reserveSpace: Tensor):
+    return aten.native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var, True, epsilon, [True, True, True])
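As a sanity check (again, not part of the commit), one way to see the new decomposition fire is to trace with `make_fx` and a decomposition table containing these two ops. A minimal sketch assuming a CUDA build; the function names are illustrative, and the `get_decompositions` import path below is the private `_src` location this commit's files live in, which may differ in other versions:

```python
import torch
from functorch import make_fx
from functorch._src.decompositions import get_decompositions  # private path; may vary by version

aten = torch.ops.aten

# Build a decomposition table for just the two newly registered ops.
decomp_table = get_decompositions([aten.cudnn_batch_norm, aten.cudnn_batch_norm_backward])

bn = torch.nn.BatchNorm2d(4).cuda()  # train mode by default

def f(x):
    return bn(x)

# On CUDA, BatchNorm dispatches to aten.cudnn_batch_norm; tracing with the
# table rewrites it into aten.native_batch_norm in the captured graph.
gm = make_fx(f, decomposition_table=decomp_table)(torch.randn(2, 4, 8, 8, device="cuda"))
print(gm.graph)  # should show native_batch_norm rather than cudnn_batch_norm
```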
