Extend BatchNorm with rank-matched channels and optional momentum #278
rsuderman wants to merge 4 commits into iree-org:main from
Conversation
… momentum
- **Rank-matched channel tensors**: SCALE, BIAS, MEAN, VAR, SAVED_MEAN, and
SAVED_INV_VARIANCE now accept tensors of shape [1, C, 1, ..., 1] (ones in
all non-feature dimensions) in addition to the canonical 1D [C] shape.
The node validates and infers strides for both forms. The ASM emitter
collapses rank-matched inputs to 1D via `flatten.using_ints` before
`native_batch_norm`, and expands rank-matched stat outputs back via
`reshape` after.
- **Optional momentum**: Momentum is no longer required. When omitted, the
ASM emitter emits `torch.constant.float 1.000000e-01` (PyTorch default).
If provided it must still be a scalar constant.
- **Tests**:
- `test_batchnorm_node.cpp`: updated "Momentum missing" section to expect
success; added rank-matched validation sections.
- `test_batchnorm_infer_asm_emitter_nchw_rank_matched.cpp`: lit test
verifying `flatten.using_ints` collapse for rank-matched MEAN/VAR.
- `test_batchnorm_infer_asm_emitter_nchw_no_momentum.cpp`: lit test
verifying `torch.constant.float` default when momentum is absent.
- **Samples**:
- `batchnorm_infer_nchw_rank_matched_scale_bias.cpp`: end-to-end inference
with [1,C,1,1] scale, bias, mean, var.
- `batchnorm_infer_nchw_no_momentum.cpp`: end-to-end inference with no
momentum tensor.
- `batchnorm_infer_ncdhw_rank_matched.cpp`: end-to-end 5D (NCDHW) inference
with rank-matched [1,C,1,1,1] channel tensors.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
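As a sanity check of the collapse/expand behavior described above, here is a pure-Python sketch (not the project's code; `flatten_rank_matched` and `bn_infer_nchw` are illustrative names) showing that a rank-matched [1, C, 1, 1] parameter tensor, once collapsed to 1D (the role `flatten.using_ints` plays in the emitted ASM), feeds the same per-channel inference math as the canonical [C] form:

```python
# Illustrative sketch: inference-mode batch norm on an NCHW tensor built from
# nested lists, with rank-matched [1, C, 1, 1] parameters collapsed to [C].

def unwrap(x):
    """Strip trailing unit dimensions: [[2.0]] -> 2.0."""
    while isinstance(x, list):
        assert len(x) == 1
        x = x[0]
    return x

def flatten_rank_matched(t):
    """Collapse [1, C, 1, ..., 1] nested lists to a flat [C] list
    (assumes C > 1 so the channel axis is detectable)."""
    while isinstance(t, list) and len(t) == 1:
        t = t[0]
    return [unwrap(c) for c in t]

def bn_infer_nchw(x, scale, bias, mean, var, eps=1e-5):
    """x: [N][C][H][W] nested lists; scale/bias/mean/var: flat [C] lists."""
    return [[[[(v - mean[c]) / (var[c] + eps) ** 0.5 * scale[c] + bias[c]
               for v in row]
              for row in x[n][c]]
             for c in range(len(x[n]))]
            for n in range(len(x))]

# Rank-matched [1, C, 1, 1] scale for C = 2, collapsed to 1D:
scale_rm = [[[[2.0]], [[0.5]]]]
flat = flatten_rank_matched(scale_rm)   # -> [2.0, 0.5]

x = [[[[1.0, 2.0], [3.0, 4.0]],
      [[5.0, 6.0], [7.0, 8.0]]]]        # N=1, C=2, H=2, W=2
out = bn_infer_nchw(x, flat, [0.0, 1.0], [0.0, 0.0], [1.0, 1.0], eps=0.0)
# channel 0 is scaled by 2.0; channel 1 is halved, then shifted by 1.0
```

The same `bn_infer_nchw` would be reached from a canonical 1D [C] scale directly, which is what makes the emitter's collapse step semantics-preserving.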
Force-pushed from 5e73229 to f50f15c
…atched

Channel tensors (scale, bias, mean, var, saved stats) must now be rank-matched [1, C, 1, ..., 1] rather than accepting the canonical 1D [C] form. Inference and shape logic updated accordingly; all batchnorm unit and lit tests updated to use rank-matched tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
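The validation rule this commit describes can be sketched as follows (a hedged Python illustration, not the project's C++ check; `is_rank_matched` is a hypothetical name): a channel tensor for a rank-R input must itself have rank R, with the input's channel count at dim 1 and ones everywhere else.

```python
# Hypothetical sketch of the rank-matched shape rule: a channel tensor must
# have shape [1, C, 1, ..., 1] matching the rank of the input tensor.

def is_rank_matched(channel_shape, input_shape):
    if len(channel_shape) != len(input_shape):
        return False  # rank mismatch
    for dim, size in enumerate(channel_shape):
        if dim == 1:
            if size != input_shape[1]:
                return False  # feature dim must equal C
        elif size != 1:
            return False  # every non-feature dim must be 1
    return True

assert is_rank_matched([1, 4, 1, 1], [2, 4, 8, 8])        # NCHW
assert is_rank_matched([1, 3, 1, 1, 1], [2, 3, 4, 4, 4])  # NCDHW
assert not is_rank_matched([4], [2, 4, 8, 8])             # 1D form now rejected
assert not is_rank_matched([1, 4, 2, 1], [2, 4, 8, 8])    # non-unit spatial dim
```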
    "BatchNorm epsilon must be a scalar constant");

-  // Momentum checks.
+  // Momentum checks (optional — omitting uses the PyTorch default 0.1).
omitting uses the PyTorch default 0.1
Is this the same behavior as hipdnn/cudnn?
Mentioned above: it should be an optional value altogether, e.g. inference mode doesn't use it. But it does not even appear consistent outside PyTorch's use case. I won't be surprised if we have to come back and tweak this once we determine the PyTorch integration plan.
I took a look at hipDNN, and it appears momentum is optional but never given a default. Is the PyTorch integration possible by setting a default value there instead of here? I think we want to keep behavior in line with hipDNN as much as possible.
However, PyTorch does not even use it for training mode. It's just kind of buggy.
I'm not sure what this means. Does PyTorch always omit momentum in their testing and just use the default (0.1)?
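For context on the 0.1 default being debated here: PyTorch's documented running-statistics update in training mode weights the *observed* batch statistic by momentum, i.e. `new = (1 - momentum) * old + momentum * observed`. A small sketch of that convention (this illustrates the documented `torch.nn.BatchNorm*` semantics, not code from this PR):

```python
# PyTorch running-statistics update convention, momentum default 0.1.
# Inference mode never performs this update, which is why momentum is
# arguably optional there.

def update_running_stat(running, batch_stat, momentum=0.1):
    # Note: opposite of the common EMA convention, where momentum
    # weights the *old* value rather than the new observation.
    return (1.0 - momentum) * running + momentum * batch_stat

r = 0.0
for observed in [1.0, 1.0, 1.0]:
    r = update_running_stat(r, observed)
# after three updates of observing 1.0, r is approximately 0.271
```

By contrast, cuDNN-style APIs take the blending factor as an explicit per-call argument with no library-level default, which matches the observation above about hipDNN.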
    "BatchNorm tensor " + name +
    " must be rank-matched with ones in all non-feature dimensions");
It might be worth making this message specific to the unit dims, and the one below specifically about the rank mismatch.
I just stripped the check. With the rank-matched case being the only route, we should be fine with it coming together into one message.
- Update asm_emitter.h comments to reflect that only rank-matched [1,C,1,...,1] channel tensors are accepted (the 1D path was removed)
- Remove the dead 1D branch from getBnExpandStatOutputOpsAsm and replace the statResultTag lambda with a direct "_raw" suffix
- Replace torch.aten.reshape with torch.aten.unflatten.int in getBnExpandStatOutputOpsAsm (the natural inverse of flatten.using_ints)
- Simplify channel tensor stride inference in batchnorm_node.h: a known shape [1,C,1,...,1] yields stride [C,1,1,...,1] directly, no loop needed
- Update lit test CHECK lines for unflatten.int

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
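The stride simplification in the last point can be sanity-checked in a few lines (a Python sketch, not the project's C++): for a contiguous row-major tensor of shape [1, C, 1, ..., 1], the generic suffix-product stride computation collapses to [C, 1, ..., 1] directly.

```python
# Generic row-major (contiguous) strides: stride[i] = product of shape[i+1:].
def contiguous_strides(shape):
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# Direct closed form for the rank-matched shape [1, C, 1, ..., 1]:
# only dim 0 has a non-unit suffix product (namely C), so no loop is needed.
def rank_matched_strides(rank, c):
    return [c] + [1] * (rank - 1)

assert contiguous_strides([1, 4, 1, 1]) == rank_matched_strides(4, 4)
assert contiguous_strides([1, 3, 1, 1, 1]) == rank_matched_strides(5, 3)
```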