
Extend BatchNorm with rank-matched channels and optional momentum#278

Open
rsuderman wants to merge 4 commits into iree-org:main from rsuderman:batchnorm_rework

Conversation

@rsuderman (Contributor) commented on Mar 27, 2026

  • Rank-matched channel tensors: SCALE, BIAS, MEAN, VAR, SAVED_MEAN, and SAVED_INV_VARIANCE now accept tensors of shape [1, C, 1, ..., 1] (ones in all non-feature dimensions) in addition to the canonical 1D [C] shape. The node validates and infers strides for both forms. The ASM emitter collapses rank-matched inputs to 1D via flatten.using_ints before native_batch_norm, and expands rank-matched stat outputs back via reshape after.

  • Optional momentum: Momentum is no longer required. When omitted, the ASM emitter emits torch.constant.float 1.000000e-01 (PyTorch default). If provided it must still be a scalar constant.

  • Tests:

    • test_batchnorm_node.cpp: updated "Momentum missing" section to expect success; added rank-matched validation sections.
    • test_batchnorm_infer_asm_emitter_nchw_rank_matched.cpp: lit test verifying flatten.using_ints collapse for rank-matched MEAN/VAR.
    • test_batchnorm_infer_asm_emitter_nchw_no_momentum.cpp: lit test verifying torch.constant.float default when momentum is absent.
  • Samples:

    • batchnorm_infer_nchw_rank_matched_scale_bias.cpp: end-to-end inference with [1,C,1,1] scale, bias, mean, var.
    • batchnorm_infer_nchw_no_momentum.cpp: end-to-end inference with no momentum tensor.
    • batchnorm_infer_ncdhw_rank_matched.cpp: end-to-end 5D (NCDHW) inference with rank-matched [1,C,1,1,1] channel tensors.
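A rough sketch in plain Python (illustration only, not code from this PR) of why the `flatten.using_ints` collapse is shape-only: a contiguous rank-matched `[1, C, 1, ..., 1]` buffer already has the same memory layout as the 1D `[C]` tensor, so only the shape metadata changes.

```python
# Hypothetical illustration of collapsing a rank-matched channel tensor
# [1, C, 1, ..., 1] to the canonical 1D [C] form, mirroring the effect of
# the flatten.using_ints the emitter inserts before native_batch_norm.

def flatten_rank_matched(shape, data):
    """Collapse [1, C, 1, ..., 1] to [C].

    For a contiguous row-major buffer the data is already laid out exactly
    like the 1D tensor, so only the shape changes; the buffer is untouched.
    """
    non_unit = [d for d in shape if d != 1]
    assert len(non_unit) <= 1, "only the feature dim may exceed 1"
    c = non_unit[0] if non_unit else 1
    return [c], data

shape, data = flatten_rank_matched([1, 3, 1, 1], [0.5, 1.0, 2.0])
# shape is now [3]; data is the same buffer.
```

The inverse step for the rank-matched stat outputs is likewise shape-only, which is why a reshape after `native_batch_norm` suffices.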

… momentum


Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
@rsuderman changed the title from "feat: Extend BatchNorm with rank-matched channel tensors and optional…" to "Extend BatchNorm with rank-matched channels and optional momentum" on Mar 27, 2026
@rsuderman requested a review from IanWood1 on March 27, 2026 at 20:14
…atched

Channel tensors (scale, bias, mean, var, saved stats) must now be
rank-matched [1, C, 1, ..., 1] rather than accepting the canonical
1D [C] form. Inference and shape logic updated accordingly; all
batchnorm unit and lit tests updated to use rank-matched tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
"BatchNorm epsilon must be a scalar constant");

// Momentum checks.
// Momentum checks (optional — omitting uses the PyTorch default 0.1).
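For context on what the default governs (a sketch of PyTorch's documented semantics, not code from this change): momentum blends each batch statistic into the running statistic during training, and 0.1 is the default the emitter materializes when the operand is omitted.

```python
# Sketch of the running-stat update that momentum controls in PyTorch's
# batch norm; 0.1 is the documented default when momentum is not given.
PYTORCH_DEFAULT_MOMENTUM = 0.1

def update_running_stat(running, batch_stat,
                        momentum=PYTORCH_DEFAULT_MOMENTUM):
    # running <- (1 - momentum) * running + momentum * batch_stat
    return (1.0 - momentum) * running + momentum * batch_stat
```

Inference mode never performs this update, which is part of why momentum can reasonably be optional.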
Member

omitting uses the PyTorch default 0.1

Is this the same behavior as hipdnn/cudnn?

Contributor Author


Mentioned above - it should be an optional value altogether, e.g. inference mode doesn't use it. But it does not even appear consistent within PyTorch's own use cases. I wouldn't be surprised if we have to come back and tweak this once we determine the PyTorch integration plan.

Member


I took a look at HipDNN and it appears momentum is optional but never given a default. Is the pytorch integration possible by setting a default value there instead of here? I think we want to keep behavior in line with hipdnn as much as possible.

Member


However pytorch does not even use it for training mode. It's just kinda buggy.

I'm not sure what this means. Does pytorch always omit momentum in their testing and just use the default (0.1)?

Comment on lines +231 to +232
"BatchNorm tensor " + name +
" must be rank-matched with ones in all non-feature dimensions");
Member


It might be worth making this message specific to the unit dims and the one below specifically about the rank mismatch

Contributor Author


I just stripped the check. With the rank-matched case being the only route, we should be fine with it coming together.

rsuderman and others added 2 commits on April 1, 2026 at 14:11
- Update asm_emitter.h comments to reflect that only rank-matched
  [1,C,1,...,1] channel tensors are accepted (1D path was removed)
- Remove dead 1D branch from getBnExpandStatOutputOpsAsm and replace
  statResultTag lambda with direct "_raw" suffix
- Replace torch.aten.reshape with torch.aten.unflatten.int in
  getBnExpandStatOutputOpsAsm (natural inverse of flatten.using_ints)
- Simplify channel tensor stride inference in batchnorm_node.h: known
  shape [1,C,1,...,1] yields stride [C,1,1,...,1] directly, no loop needed
- Update lit test CHECK lines for unflatten.int

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
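The stride simplification described in the last commit can be sketched as follows (a hypothetical helper, not the node's actual code): a contiguous row-major tensor of shape `[1, C, 1, ..., 1]` has leading stride `C` and unit strides everywhere else, so no loop over dimensions is needed.

```python
# Hypothetical sketch of stride inference for a known rank-matched shape
# [1, C, 1, ..., 1]: the contiguous row-major strides are [C, 1, ..., 1],
# since every dimension after the first is followed only by unit dims.

def rank_matched_strides(shape):
    assert len(shape) >= 2 and shape[0] == 1, "expected [1, C, 1, ..., 1]"
    assert all(d == 1 for d in shape[2:]), "non-feature dims must be 1"
    return [shape[1]] + [1] * (len(shape) - 1)
```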
@rsuderman requested a review from IanWood1 on April 1, 2026 at 21:16