
Implement BatchnormBwd functionality into hip-kernel-provider #5558

Draft
BalintCsala wants to merge 11 commits into ROCm:develop from StreamHPC:users/BalintCsala/hipdnn-direct-batchnorm-bwd-rebased

Conversation

@BalintCsala
Contributor

Caution

Must be merged after #4885

Motivation

This PR moves MIOpen's BatchnormBwd implementation into the hip-kernel-provider hipDNN plugin.

Technical Details

  • Backward batchnorm plan (BatchnormBwdPlan): Implements miopenBatchNormalizationBackward_V2 with support for both spatial and per-activation modes. Selects between single-kernel and multi-kernel (stash-based) variants via a heuristic based on tensor dimensions.
  • Fused backward+activation: Supports miopenBatchNormBackwardActivation by fusing a pointwise activation backward pass (ReLU, sigmoid, tanh, ELU, etc.) into the batchnorm backward kernel.
  • Plan routing (BatchnormPlanBuilder): Extended isApplicable() and buildPlan() to handle single-node backward (1-node graph) and fused backward+activation (3-node graph: BnInference → Pointwise → BnBackward).
  • Applicability checks: Added checkBatchnormBackwardTensorConfigSupported(), checkBatchnormInferenceActivationBackwardTensorConfigSupported(), and checkBatchnormBwdActivationModeSupported() for validating tensor layouts and data types.
  • Activation utilities (HipKernelUtils): Added ActivationMode enum, ActivationParams struct, and parseActivation() to map PointwiseAttributes to activation parameters used by backward kernels.
  • HIP kernels: Added BatchNormBwdSpatial.cpp, BatchNormBwdSpatialMultiple.cpp (multi-kernel stash variant), BatchNormBwdPerAct.cpp, and BatchnormStash.hpp.
  • Integration tests: Backward and backward+activation tests across NCHW/NHWC/NCDHW/NDHWC layouts with fp32/fp16/bfloat16.
  • Unit test: Validates BatchnormBwdParams construction for single backward and fused activation cases.

Test Plan

Testing was done through the included unit and integration tests.

Test Result

All tests passed on gfx90a.

Submission Checklist

Add backward batchnorm operations (miopenBatchNormalizationBackward_V2) and
fused backward+activation (miopenBatchNormBackwardActivation) support to the
hip-kernel-provider plugin.

New files:
- BatchnormBwdPlan.hpp/cpp: Plan implementation for backward batchnorm
  with spatial (single and multi-kernel) and per-activation modes
- BatchNormBwdSpatial.cpp: Spatial backward HIP kernel
- BatchNormBwdSpatialMultiple.cpp: Multi-pass spatial backward HIP kernel
- BatchNormBwdPerAct.cpp: Per-activation backward HIP kernel
- BatchnormStash.hpp: Stash utilities for multi-kernel approach

Modified files:
- BatchnormPlanBuilder.cpp: Route backward graphs (1-node and 3-node
  fused BnInference+Pointwise+BnBackward) to backward plan
- BatchnormApplicabilityChecks.hpp/cpp: Add backward tensor config and
  activation mode validation
- HipKernelUtils.hpp/cpp: Add findDeviceBuffer declaration and correct
  parseActivation implementation with full activation mode support

Tests:
- Integration tests for backward and backward+activation operations
  across NCHW/NHWC/NCDHW/NDHWC layouts with fp32/fp16/bfloat16
- Unit tests for BatchnormBwdParams initialization
@BalintCsala force-pushed the users/BalintCsala/hipdnn-direct-batchnorm-bwd-rebased branch from e63e7ab to 37e23ba on March 18, 2026 at 17:16
