Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 64 additions & 37 deletions include/fusilli/node/batchnorm_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ namespace fusilli {
// each channel C. The input X has logical shape [N, C, *] where C is at
// dimension 1 in logical (NCHW) order.
//
// Scale (gamma), bias (beta), running mean, and running variance are all 1D
// tensors of shape [C].
// Scale (gamma), bias (beta), running mean, and running variance are all
// rank-matched tensors of shape [1, C, 1, ..., 1].
//
// Inference: requires running MEAN and VAR; outputs Y only.
// Training: running MEAN and VAR are optional; outputs Y, SAVED_MEAN, and
Expand Down Expand Up @@ -125,12 +125,12 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
FUSILLI_RETURN_ERROR_IF(!eT->isScalar(), ErrorCode::InvalidAttribute,
"BatchNorm epsilon must be a scalar constant");

// Momentum checks.
// Momentum checks (optional — omitting uses the PyTorch default 0.1).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

omitting uses the PyTorch default 0.1

Is this the same behavior as hipdnn/cudnn?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mentioned above - it should be an optional value altogether, e.g. inference mode doesn't use it. But it does not even appear consistent without PyTorch's use case. I won't be surprised if we have to come back and tweak this once we determine the PyTorch integration plan.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look at hipDNN and it appears momentum is optional but never given a default. Is the PyTorch integration possible by setting a default value there instead of here? I think we want to keep behavior in line with hipDNN as much as possible.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However pytorch does not even use it for training mode. Its just kinda buggy.

I'm not sure what this means. Does PyTorch always omit momentum in their testing and just use the default (0.1)?

std::shared_ptr<TensorAttr> mT = batchnormAttr.getMomentum();
FUSILLI_RETURN_ERROR_IF(!mT, ErrorCode::AttributeNotSet,
"BatchNorm momentum not set");
FUSILLI_RETURN_ERROR_IF(!mT->isScalar(), ErrorCode::InvalidAttribute,
"BatchNorm momentum must be a scalar constant");
if (mT) {
FUSILLI_RETURN_ERROR_IF(!mT->isScalar(), ErrorCode::InvalidAttribute,
"BatchNorm momentum must be a scalar constant");
}

return ok();
}
Expand All @@ -145,25 +145,34 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
std::shared_ptr<TensorAttr> yT = batchnormAttr.getY();

const std::vector<int64_t> &xDim = xT->getDim();
const std::vector<int64_t> channel1DDim = {xDim[1]};
const std::vector<int64_t> channel1DStride = {1};
size_t xRank = xDim.size();
// Build the rank-matched channel dim: [1, C, 1, ..., 1]
std::vector<int64_t> channelRankMatchedDim(xRank, 1);
channelRankMatchedDim[1] = xDim[1];

auto infer1DTensor = [&](const std::shared_ptr<TensorAttr> &t) {
auto inferChannelTensor = [&](const std::shared_ptr<TensorAttr> &t) {
if (t->getDim().empty())
t->setDim(channel1DDim);
if (t->getStride().empty())
t->setStride(channel1DStride);
t->setDim(channelRankMatchedDim);
if (t->getStride().empty()) {
const std::vector<int64_t> &dim = t->getDim();
size_t rank = dim.size();
// Compute contiguous stride for the rank-matched tensor.
std::vector<int64_t> stride(rank, 1);
for (int64_t i = (int64_t)rank - 2; i >= 0; --i)
stride[i] = stride[i + 1] * dim[i + 1];
t->setStride(stride);
}
};

// Infer 1D channel tensors.
// Infer rank-matched channel tensors.
if (auto sT = batchnormAttr.getSCALE())
infer1DTensor(sT);
inferChannelTensor(sT);
if (auto bT = batchnormAttr.getBIAS())
infer1DTensor(bT);
inferChannelTensor(bT);
if (auto meanT = batchnormAttr.getMEAN())
infer1DTensor(meanT);
inferChannelTensor(meanT);
if (auto varT = batchnormAttr.getVAR())
infer1DTensor(varT);
inferChannelTensor(varT);

// Infer shape and stride of output Y tensor (same as X).
if (yT->getDim().empty())
Expand All @@ -173,8 +182,8 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {

// Infer saved statistics shapes for training.
if (isTrainingForwardPhase()) {
infer1DTensor(batchnormAttr.getSAVED_MEAN());
infer1DTensor(batchnormAttr.getSAVED_INV_VARIANCE());
inferChannelTensor(batchnormAttr.getSAVED_MEAN());
inferChannelTensor(batchnormAttr.getSAVED_INV_VARIANCE());
}

return ok();
Expand All @@ -188,7 +197,6 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
std::shared_ptr<TensorAttr> yT = batchnormAttr.getY();

const std::vector<int64_t> &xDim = xT->getDim();
const std::vector<int64_t> expectedCDim = {xDim[1]};

// Shape check for output Y tensor.
FUSILLI_RETURN_ERROR_IF(
Expand All @@ -202,32 +210,51 @@ class BatchNormNode : public NodeCRTP<BatchNormNode> {
"' is neither contiguous nor channels-last as "
"defined by its stride");

// Shape checks for 1D channel tensors.
auto check1DShape = [&](const std::shared_ptr<TensorAttr> &t,
const std::string &name) -> ErrorObject {
// Shape checks for rank-matched channel tensors of form [1, C, 1, ..., 1].
auto checkChannelShape = [&](const std::shared_ptr<TensorAttr> &t,
const std::string &name) -> ErrorObject {
if (!t)
return ok();
const std::vector<int64_t> &tDim = t->getDim();
const std::vector<int64_t> &tStride = t->getStride();
size_t xRank = xDim.size();

if (tDim.size() == xRank) {
// Rank-matched form [1, C, 1, ..., 1]: channel dim must equal C and
// all other dims must be 1.
bool validShape = (tDim[1] == xDim[1]);
for (size_t i = 0; i < xRank && validShape; ++i)
if (i != 1 && tDim[i] != 1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be directly used as the condition for FUSILLI_RETURN_ERROR_IF?

validShape = false;
FUSILLI_RETURN_ERROR_IF(
!validShape, ErrorCode::InvalidAttribute,
"BatchNorm tensor " + name +
" must be rank-matched with ones in all non-feature dimensions");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth making this message specific to the unit dims and the one below specifically about the rank mismatch.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just stripped the check. With the rank-matched case being the only route, we should be fine with it coming together.

FUSILLI_RETURN_ERROR_IF(
tStride.size() != xRank || tStride[1] != 1,
ErrorCode::InvalidAttribute,
"BatchNorm tensor " + name +
" must have unit stride at the channel dimension");
return ok();
}

FUSILLI_RETURN_ERROR_IF(
t->getDim() != expectedCDim, ErrorCode::InvalidAttribute,
true, ErrorCode::InvalidAttribute,
"BatchNorm tensor " + name +
" must be 1D with size equal to channel dimension C");
FUSILLI_RETURN_ERROR_IF(t->getStride() != std::vector<int64_t>{1},
ErrorCode::InvalidAttribute,
"BatchNorm tensor " + name +
" must have unit stride");
" must be rank-matched with ones in all non-feature dimensions");
return ok();
};

FUSILLI_CHECK_ERROR(check1DShape(batchnormAttr.getSCALE(), "SCALE"));
FUSILLI_CHECK_ERROR(check1DShape(batchnormAttr.getBIAS(), "BIAS"));
FUSILLI_CHECK_ERROR(check1DShape(batchnormAttr.getMEAN(), "MEAN"));
FUSILLI_CHECK_ERROR(check1DShape(batchnormAttr.getVAR(), "VAR"));
FUSILLI_CHECK_ERROR(checkChannelShape(batchnormAttr.getSCALE(), "SCALE"));
FUSILLI_CHECK_ERROR(checkChannelShape(batchnormAttr.getBIAS(), "BIAS"));
FUSILLI_CHECK_ERROR(checkChannelShape(batchnormAttr.getMEAN(), "MEAN"));
FUSILLI_CHECK_ERROR(checkChannelShape(batchnormAttr.getVAR(), "VAR"));

if (isTrainingForwardPhase()) {
FUSILLI_CHECK_ERROR(
check1DShape(batchnormAttr.getSAVED_MEAN(), "SAVED_MEAN"));
FUSILLI_CHECK_ERROR(check1DShape(batchnormAttr.getSAVED_INV_VARIANCE(),
"SAVED_INV_VARIANCE"));
checkChannelShape(batchnormAttr.getSAVED_MEAN(), "SAVED_MEAN"));
FUSILLI_CHECK_ERROR(checkChannelShape(batchnormAttr.getSAVED_INV_VARIANCE(),
"SAVED_INV_VARIANCE"));
}

return ok();
Expand Down
Loading
Loading