Skip to content

Commit a3398a3

Browse files
rsuderman and claude authored
Add BatchNorm node, attributes, and ASM emitter support (#259)
Introduces BatchNormNode and BatchnormAttr for both inference and training forward phases, with MLIR assembly emission via torch.aten.native_batch_norm. Includes unit tests and lit tests for NCHW layout. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Signed-off-by: Rob Suderman <rob.suderman@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 6165d53 commit a3398a3

18 files changed

+2189
-26
lines changed

include/fusilli.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
// Attributes / Types:
3333
#include "fusilli/attributes/attributes.h" // IWYU pragma: export
34+
#include "fusilli/attributes/batchnorm_attributes.h" // IWYU pragma: export
3435
#include "fusilli/attributes/common.h" // IWYU pragma: export
3536
#include "fusilli/attributes/conv_attributes.h" // IWYU pragma: export
3637
#include "fusilli/attributes/custom_op_attributes.h" // IWYU pragma: export
@@ -43,6 +44,7 @@
4344
#include "fusilli/attributes/types.h" // IWYU pragma: export
4445

4546
// Nodes:
47+
#include "fusilli/node/batchnorm_node.h" // IWYU pragma: export
4648
#include "fusilli/node/conv_node.h" // IWYU pragma: export
4749
#include "fusilli/node/custom_op_node.h" // IWYU pragma: export
4850
#include "fusilli/node/layernorm_node.h" // IWYU pragma: export
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright 2026 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===----------------------------------------------------------------------===//
//
// This file contains attributes (compile-time constant metadata) for
// batch normalization nodes.
//
//===----------------------------------------------------------------------===//

#ifndef FUSILLI_ATTRIBUTES_BATCHNORM_ATTRIBUTES_H
#define FUSILLI_ATTRIBUTES_BATCHNORM_ATTRIBUTES_H

#include "fusilli/attributes/attributes.h"
#include "fusilli/attributes/common.h"
#include "fusilli/attributes/tensor_attributes.h"

#include <cstdint>
#include <memory>
#include <unordered_map>

namespace fusilli {

// Attributes for a batch normalization node: the input/output tensors,
// the epsilon and momentum scalars (carried as tensor inputs), and the
// forward phase (inference vs. training) the node runs in.
class BatchnormAttr : public AttributesCRTP<BatchnormAttr> {
public:
  // Names for Tensor Inputs and Outputs.
  enum class InputNames : uint8_t {
    X,        // input activations
    SCALE,    // per-channel scale
    BIAS,     // per-channel bias
    MEAN,     // running mean
    VAR,      // running variance
    EPSILON,  // numerical-stability scalar
    MOMENTUM  // running-statistics update factor
  };
  enum class OutputNames : uint8_t {
    Y,                  // normalized output
    SAVED_MEAN,         // batch mean (training phase only)
    SAVED_INV_VARIANCE  // batch inverse variance (training phase only)
  };

  std::unordered_map<InputNames, std::shared_ptr<TensorAttr>> inputs;
  std::unordered_map<OutputNames, std::shared_ptr<TensorAttr>> outputs;

  // Setters:
  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BatchnormAttr, InputNames, X)
  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BatchnormAttr, InputNames, SCALE)
  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BatchnormAttr, InputNames, BIAS)
  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BatchnormAttr, InputNames, MEAN)
  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BatchnormAttr, InputNames, VAR)
  FUSILLI_GENERIC_OUTPUT_TENSOR_SETTER(BatchnormAttr, OutputNames, Y)
  FUSILLI_GENERIC_OUTPUT_TENSOR_SETTER(BatchnormAttr, OutputNames, SAVED_MEAN)
  FUSILLI_GENERIC_OUTPUT_TENSOR_SETTER(BatchnormAttr, OutputNames,
                                       SAVED_INV_VARIANCE)

  // Getters:
  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, X)
  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, SCALE)
  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, BIAS)
  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, MEAN)
  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, VAR)
  FUSILLI_GENERIC_OUTPUT_TENSOR_GETTER(OutputNames, Y)
  FUSILLI_GENERIC_OUTPUT_TENSOR_GETTER(OutputNames, SAVED_MEAN)
  FUSILLI_GENERIC_OUTPUT_TENSOR_GETTER(OutputNames, SAVED_INV_VARIANCE)

  // Epsilon and momentum travel as tensor inputs alongside X/SCALE/etc.,
  // so their accessors delegate to the generic input map.
  BatchnormAttr &setEpsilon(const std::shared_ptr<TensorAttr> &epsilon) {
    return setInput(InputNames::EPSILON, epsilon);
  }

  std::shared_ptr<TensorAttr> getEpsilon() const {
    return getInput(InputNames::EPSILON);
  }

  BatchnormAttr &setMomentum(const std::shared_ptr<TensorAttr> &momentum) {
    return setInput(InputNames::MOMENTUM, momentum);
  }

  std::shared_ptr<TensorAttr> getMomentum() const {
    return getInput(InputNames::MOMENTUM);
  }

  // Forward phase selects between inference and training behavior
  // (training additionally produces saved mean / inverse variance).
  BatchnormAttr &setForwardPhase(NormFwdPhase forwardPhase) {
    fwdPhase_ = forwardPhase;
    return *this;
  }

  NormFwdPhase getForwardPhase() const { return fwdPhase_; }

private:
  NormFwdPhase fwdPhase_ = NormFwdPhase::NOT_SET;
};

} // namespace fusilli

#endif // FUSILLI_ATTRIBUTES_BATCHNORM_ATTRIBUTES_H

include/fusilli/graph/graph.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef FUSILLI_GRAPH_GRAPH_H
1515
#define FUSILLI_GRAPH_GRAPH_H
1616

17+
#include "fusilli/attributes/batchnorm_attributes.h"
1718
#include "fusilli/attributes/common.h"
1819
#include "fusilli/attributes/conv_attributes.h"
1920
#include "fusilli/attributes/custom_op_attributes.h"
@@ -30,6 +31,7 @@
3031
#include "fusilli/backend/compile_session.h"
3132
#include "fusilli/backend/handle.h"
3233
#include "fusilli/graph/context.h"
34+
#include "fusilli/node/batchnorm_node.h"
3335
#include "fusilli/node/conv_node.h"
3436
#include "fusilli/node/custom_op_node.h"
3537
#include "fusilli/node/layernorm_node.h"
@@ -268,6 +270,13 @@ class Graph : public INode {
268270
const std::shared_ptr<TensorAttr> &w,
269271
ConvDGradAttr &attributes);
270272
std::array<std::shared_ptr<TensorAttr>, 3>
273+
batchnorm(const std::shared_ptr<TensorAttr> &x,
274+
const std::shared_ptr<TensorAttr> &scale,
275+
const std::shared_ptr<TensorAttr> &bias,
276+
const std::shared_ptr<TensorAttr> &mean,
277+
const std::shared_ptr<TensorAttr> &var, BatchnormAttr &attributes);
278+
279+
std::array<std::shared_ptr<TensorAttr>, 3>
271280
layernorm(const std::shared_ptr<TensorAttr> &x,
272281
const std::shared_ptr<TensorAttr> &scale,
273282
const std::shared_ptr<TensorAttr> &bias, LayernormAttr &attributes);
@@ -727,6 +736,60 @@ Graph::convDGrad(const std::shared_ptr<TensorAttr> &dy,
727736
return dx;
728737
}
729738

739+
// Create a BatchNormNode, populate it with the specified attributes, create
740+
// output tensors and add the node to the graph's sub nodes.
741+
inline std::array<std::shared_ptr<TensorAttr>, 3>
742+
Graph::batchnorm(const std::shared_ptr<TensorAttr> &x,
743+
const std::shared_ptr<TensorAttr> &scale,
744+
const std::shared_ptr<TensorAttr> &bias,
745+
const std::shared_ptr<TensorAttr> &mean,
746+
const std::shared_ptr<TensorAttr> &var,
747+
BatchnormAttr &batchnormAttr) {
748+
// Populate names when not set.
749+
if (batchnormAttr.getName().empty())
750+
batchnormAttr.setName("batchnorm_" + std::to_string(subNodes_.size()));
751+
if (x && x->getName().empty())
752+
x->setName(batchnormAttr.getName() + "_X");
753+
if (scale && scale->getName().empty())
754+
scale->setName(batchnormAttr.getName() + "_SCALE");
755+
if (bias && bias->getName().empty())
756+
bias->setName(batchnormAttr.getName() + "_BIAS");
757+
if (mean && mean->getName().empty())
758+
mean->setName(batchnormAttr.getName() + "_MEAN");
759+
if (var && var->getName().empty())
760+
var->setName(batchnormAttr.getName() + "_VAR");
761+
auto eps = batchnormAttr.getEpsilon();
762+
if (eps && eps->getName().empty())
763+
eps->setName(batchnormAttr.getName() + "_EPSILON");
764+
auto mom = batchnormAttr.getMomentum();
765+
if (mom && mom->getName().empty())
766+
mom->setName(batchnormAttr.getName() + "_MOMENTUM");
767+
768+
FUSILLI_LOG_LABEL_ENDL("INFO: Adding BatchNorm '" << batchnormAttr.getName()
769+
<< "' to Graph");
770+
771+
// Set inputs.
772+
batchnormAttr.setX(x).setSCALE(scale).setBIAS(bias).setMEAN(mean).setVAR(var);
773+
774+
// Set outputs.
775+
std::shared_ptr<TensorAttr> y = outputTensor(batchnormAttr.getName() + "_Y");
776+
std::shared_ptr<TensorAttr> savedMean = nullptr;
777+
std::shared_ptr<TensorAttr> savedInvVar = nullptr;
778+
if (batchnormAttr.getForwardPhase() == NormFwdPhase::TRAINING) {
779+
savedMean = outputTensor(batchnormAttr.getName() + "_SAVED_MEAN");
780+
savedInvVar = outputTensor(batchnormAttr.getName() + "_SAVED_INV_VARIANCE");
781+
}
782+
batchnormAttr.setY(y);
783+
batchnormAttr.setSAVED_MEAN(savedMean);
784+
batchnormAttr.setSAVED_INV_VARIANCE(savedInvVar);
785+
786+
// Create node and add to Graph's subNodes_.
787+
subNodes_.emplace_back(
788+
std::make_unique<BatchNormNode>(std::move(batchnormAttr), context));
789+
790+
return {std::move(y), std::move(savedMean), std::move(savedInvVar)};
791+
}
792+
730793
// Create a LayerNormNode, populate it with the specified attributes, create
731794
// output tensors and add the node to the graph's sub nodes
732795
inline std::array<std::shared_ptr<TensorAttr>, 3>

0 commit comments

Comments (0)