Skip to content

Commit f50f15c

Browse files
rsuderman and claude committed
feat: Extend BatchNorm with rank-matched channel tensors and optional momentum
- **Rank-matched channel tensors**: SCALE, BIAS, MEAN, VAR, SAVED_MEAN, and SAVED_INV_VARIANCE now accept tensors of shape [1, C, 1, ..., 1] (ones in all non-feature dimensions) in addition to the canonical 1D [C] shape. The node validates and infers strides for both forms. The ASM emitter collapses rank-matched inputs to 1D via `flatten.using_ints` before `native_batch_norm`, and expands rank-matched stat outputs back via `reshape` after.
- **Optional momentum**: Momentum is no longer required. When omitted, the ASM emitter emits `torch.constant.float 1.000000e-01` (the PyTorch default). If provided, it must still be a scalar constant.
- **Tests**:
  - `test_batchnorm_node.cpp`: updated "Momentum missing" section to expect success; added rank-matched validation sections.
  - `test_batchnorm_infer_asm_emitter_nchw_rank_matched.cpp`: lit test verifying `flatten.using_ints` collapse for rank-matched MEAN/VAR.
  - `test_batchnorm_infer_asm_emitter_nchw_no_momentum.cpp`: lit test verifying `torch.constant.float` default when momentum is absent.
- **Samples**:
  - `batchnorm_infer_nchw_rank_matched_scale_bias.cpp`: end-to-end inference with [1,C,1,1] scale, bias, mean, var.
  - `batchnorm_infer_nchw_no_momentum.cpp`: end-to-end inference with no momentum tensor.
  - `batchnorm_infer_ncdhw_rank_matched.cpp`: end-to-end 5D (NCDHW) inference with rank-matched [1,C,1,1,1] channel tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
1 parent a3398a3 commit f50f15c

10 files changed

+976
-103
lines changed

include/fusilli/node/batchnorm_node.h

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
125125
FUSILLI_RETURN_ERROR_IF(!eT->isScalar(), ErrorCode::InvalidAttribute,
126126
"BatchNorm epsilon must be a scalar constant");
127127

128-
// Momentum checks.
128+
// Momentum checks (optional — omitting uses the PyTorch default 0.1).
129129
std::shared_ptr<TensorAttr> mT = batchnormAttr.getMomentum();
130-
FUSILLI_RETURN_ERROR_IF(!mT, ErrorCode::AttributeNotSet,
131-
"BatchNorm momentum not set");
132-
FUSILLI_RETURN_ERROR_IF(!mT->isScalar(), ErrorCode::InvalidAttribute,
133-
"BatchNorm momentum must be a scalar constant");
130+
if (mT) {
131+
FUSILLI_RETURN_ERROR_IF(!mT->isScalar(), ErrorCode::InvalidAttribute,
132+
"BatchNorm momentum must be a scalar constant");
133+
}
134134

135135
return ok();
136136
}
@@ -151,8 +151,19 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
151151
auto infer1DTensor = [&](const std::shared_ptr<TensorAttr> &t) {
152152
if (t->getDim().empty())
153153
t->setDim(channel1DDim);
154-
if (t->getStride().empty())
155-
t->setStride(channel1DStride);
154+
if (t->getStride().empty()) {
155+
const std::vector<int64_t> &dim = t->getDim();
156+
size_t rank = dim.size();
157+
if (rank == xDim.size()) {
158+
// Rank-matched tensor [1, C, 1, ...]: compute contiguous stride.
159+
std::vector<int64_t> stride(rank, 1);
160+
for (int64_t i = (int64_t)rank - 2; i >= 0; --i)
161+
stride[i] = stride[i + 1] * dim[i + 1];
162+
t->setStride(stride);
163+
} else {
164+
t->setStride(channel1DStride);
165+
}
166+
}
156167
};
157168

158169
// Infer 1D channel tensors.
@@ -203,18 +214,49 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
203214
"defined by its stride");
204215

205216
// Shape checks for 1D channel tensors.
217+
// Accepts either the canonical 1D form [C] with unit stride, or a
218+
// rank-matched form [1, C, 1, ..., 1] with unit stride at the channel dim.
206219
// Validate the shape/stride of a per-channel BatchNorm tensor.
//
// Accepts either:
//   - the canonical 1D form [C] with unit stride {1}, or
//   - a rank-matched form [1, C, 1, ..., 1] (same rank as X, ones in all
//     non-feature dimensions) with unit stride at the channel dimension.
// A null tensor passes unchecked (the input is optional for this node).
auto check1DShape = [&](const std::shared_ptr<TensorAttr> &t,
                        const std::string &name) -> ErrorObject {
  if (!t)
    return ok();
  const std::vector<int64_t> &tDim = t->getDim();
  const std::vector<int64_t> &tStride = t->getStride();
  const size_t xRank = xDim.size();

  if (tDim == expectedCDim) {
    // Canonical 1D form [C]: stride must be exactly {1}.
    FUSILLI_RETURN_ERROR_IF(
        tStride != std::vector<int64_t>{1}, ErrorCode::InvalidAttribute,
        "BatchNorm tensor " + name + " must have unit stride");
    return ok();
  }

  // Rank-matched form is only meaningful when X has a channel dimension
  // (rank >= 2); the guard also keeps tDim[1]/xDim[1] in bounds.
  if (tDim.size() == xRank && xRank >= 2) {
    // Channel dim must equal C; every other dim must be 1.
    bool validShape = (tDim[1] == xDim[1]);
    for (size_t i = 0; i < xRank && validShape; ++i)
      if (i != 1 && tDim[i] != 1)
        validShape = false;
    FUSILLI_RETURN_ERROR_IF(
        !validShape, ErrorCode::InvalidAttribute,
        "BatchNorm tensor " + name +
            " must be 1D with size equal to channel dimension C"
            " or rank-matched with ones in all non-feature dimensions");
    // Only the channel dimension's stride is constrained; the other
    // (size-1) dimensions may carry any stride.
    FUSILLI_RETURN_ERROR_IF(
        tStride.size() != xRank || tStride[1] != 1,
        ErrorCode::InvalidAttribute,
        "BatchNorm tensor " + name +
            " must have unit stride at the channel dimension");
    return ok();
  }

  // Neither accepted form matched: report the supported shapes.
  FUSILLI_RETURN_ERROR_IF(
      true, ErrorCode::InvalidAttribute,
      "BatchNorm tensor " + name +
          " must be 1D with size equal to channel dimension C"
          " or rank-matched with ones in all non-feature dimensions");
  return ok();
};
220262

0 commit comments

Comments
 (0)