Commit f50f15c
feat: Extend BatchNorm with rank-matched channel tensors and optional momentum
- **Rank-matched channel tensors**: SCALE, BIAS, MEAN, VAR, SAVED_MEAN, and
SAVED_INV_VARIANCE now accept tensors of shape [1, C, 1, ..., 1] (ones in
all non-feature dimensions) in addition to the canonical 1D [C] shape.
The node validates and infers strides for both forms. The ASM emitter
collapses rank-matched inputs to 1D via `flatten.using_ints` before
`native_batch_norm`, and expands rank-matched stat outputs back via
`reshape` after.
- **Optional momentum**: Momentum is no longer required. When omitted, the
ASM emitter emits `torch.constant.float 1.000000e-01` (PyTorch default).
If provided, it must still be a scalar constant.
- **Tests**:
  - `test_batchnorm_node.cpp`: updated "Momentum missing" section to expect
    success; added rank-matched validation sections.
  - `test_batchnorm_infer_asm_emitter_nchw_rank_matched.cpp`: lit test
    verifying `flatten.using_ints` collapse for rank-matched MEAN/VAR.
  - `test_batchnorm_infer_asm_emitter_nchw_no_momentum.cpp`: lit test
    verifying `torch.constant.float` default when momentum is absent.
- **Samples**:
  - `batchnorm_infer_nchw_rank_matched_scale_bias.cpp`: end-to-end inference
    with [1,C,1,1] scale, bias, mean, var.
  - `batchnorm_infer_nchw_no_momentum.cpp`: end-to-end inference with no
    momentum tensor.
  - `batchnorm_infer_ncdhw_rank_matched.cpp`: end-to-end 5D (NCDHW) inference
    with rank-matched [1,C,1,1,1] channel tensors.
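
The shape rule and the momentum fallback described above can be sketched roughly as follows. This is an illustrative sketch only, not the code in this commit: `isValidChannelShape`, `collapseToCanonical`, and `momentumOrDefault` are hypothetical helper names, and the sketch assumes the feature dimension is at index 1 (as in NCHW and NCDHW).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not the Fusilli API): returns true if `shape` is either
// the canonical 1D [C] form or the rank-matched [1, C, 1, ..., 1] form for an
// input of rank `inputRank`. Assumes the feature dimension is at index 1.
bool isValidChannelShape(const std::vector<int64_t> &shape, int64_t channels,
                         std::size_t inputRank) {
  // Canonical form: a single dimension of size C.
  if (shape.size() == 1)
    return shape[0] == channels;
  // Rank-matched form: same rank as the input, C at dim 1, ones elsewhere.
  if (inputRank < 2 || shape.size() != inputRank)
    return false;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    const int64_t expected = (i == 1) ? channels : 1;
    if (shape[i] != expected)
      return false;
  }
  return true;
}

// Hypothetical helper: the 1D shape a channel tensor has after collapsing,
// mirroring what the emitter achieves with `flatten.using_ints`.
std::vector<int64_t> collapseToCanonical(const std::vector<int64_t> &shape,
                                         int64_t channels,
                                         std::size_t inputRank) {
  assert(isValidChannelShape(shape, channels, inputRank));
  (void)shape;     // Only read by the assert; avoids an unused warning
  (void)inputRank; // when assertions are compiled out.
  return {channels};
}

// Hypothetical helper: falls back to 0.1, the default this commit emits as
// `torch.constant.float 1.000000e-01`, when momentum is omitted.
double momentumOrDefault(const std::optional<double> &momentum) {
  return momentum.value_or(0.1);
}
```

For a 4D NCHW input with C = 8, both `{8}` and `{1, 8, 1, 1}` pass the check, while `{8, 1, 1, 1}` and `{1, 8, 1}` are rejected.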
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
1 parent a3398a3 commit f50f15c
File tree
10 files changed, +976 -103 lines changed
- include/fusilli
  - node
  - support
- samples
  - batchnorm
- tests
  - lit