
Commit dc4810f

[Fusilli] Added ASM emitter for Layernorm node (#71)
The third part of #45.

Signed-off-by: Alexandra Sidorova <[email protected]>
Co-authored-by: Sambhav Jain <[email protected]>
1 parent b395937 commit dc4810f

17 files changed (+1647, -0 lines)

include/fusilli/attributes/tensor_attributes.h

Lines changed: 2 additions & 0 deletions
@@ -647,6 +647,8 @@ class TensorAttr {
 
   // To represent scalar constants either obtained through
   // constant folding, or passed in as scalars during execution.
+  // These constants are inlined in the ASM emitters,
+  // so they should not be in the variant pack.
  bool isScalar_ = false;
  std::optional<scalar_t> scalarValue_ = std::nullopt;
};

include/fusilli/backend/runtime.h

Lines changed: 8 additions & 0 deletions
@@ -242,6 +242,14 @@ inline ErrorObject Graph::execute(
 
   // Populate input buffers.
   for (const auto &input : fullGraphInputsSorted_) {
+    // Scalar constants should not be used in the variantPack.
+    if (input->isScalar()) {
+      FUSILLI_RETURN_ERROR_IF(variantPack.contains(input),
+                              ErrorCode::VariantPackError,
+                              "Scalar constant tensor found in variantPack");
+      continue;
+    }
+
     FUSILLI_RETURN_ERROR_IF(!variantPack.contains(input), // C++20
                             ErrorCode::VariantPackError,
                             "Input tensor missing from variantPack");

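Not part of the commit, but for context, a minimal self-contained C++ sketch of the rule this hunk enforces, using toy stand-ins (the Tensor, VariantPack, and checkInputs names below are hypothetical, and exceptions stand in for Fusilli's FUSILLI_RETURN_ERROR_IF error handling): inline scalar constants must not be bound by the caller, while every other graph input must be.

#include <memory>
#include <set>
#include <stdexcept>
#include <vector>

// Toy stand-ins; the real TensorAttr, variant pack, and error handling
// are Fusilli-specific.
struct Tensor {
  bool isScalar = false;
};
using VariantPack = std::set<const Tensor *>;

// Same rule as the hunk above: skip inline scalar constants (and reject a
// caller that bound one); require a binding for every other graph input.
void checkInputs(const std::vector<std::shared_ptr<Tensor>> &inputs,
                 const VariantPack &pack) {
  for (const auto &input : inputs) {
    if (input->isScalar) {
      if (pack.count(input.get()))
        throw std::runtime_error("Scalar constant tensor found in variantPack");
      continue;
    }
    if (!pack.count(input.get()))
      throw std::runtime_error("Input tensor missing from variantPack");
  }
}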
include/fusilli/node/layernorm_node.h

Lines changed: 18 additions & 0 deletions
@@ -41,6 +41,15 @@ class LayerNormNode : public NodeCRTP<LayerNormNode> {
   LayerNormNode(LayernormAttr &&attr, const Context &ctx)
       : NodeCRTP(ctx), layernormAttr(std::move(attr)) {}
 
+  // ASM emitter methods.
+  std::string emitNodePreAsm() const override final;
+  std::string getOperandNamesAsm() const;
+  std::string getOperandTypesAsm() const;
+  std::string getResultNamesAsm() const;
+  std::string getResultTypesAsm() const;
+  std::string getNormalizedShapeOpsAsm() const;
+  std::string getEpsilonOpsAsm() const;
+
   const std::string &getName() const override final {
     return layernormAttr.getName();
   }
@@ -257,6 +266,15 @@ class LayerNormNode : public NodeCRTP<LayerNormNode> {
     return layernormAttr.getForwardPhase() == NormFwdPhase::TRAINING;
   }
 
+  // Returns the shape over which normalization is applied:
+  // the input tensor's shape excluding the batch dimension (dim 0),
+  // as normalization is computed independently for each sample in the batch.
+  std::vector<int64_t> getNormalizedShape() const {
+    const std::vector<int64_t> &xDim = layernormAttr.getX()->getDim();
+    std::vector<int64_t> normalizedShape(xDim.cbegin() + 1, xDim.cend());
+    return normalizedShape;
+  }
+
   std::pair<std::vector<int64_t>, std::vector<int64_t>>
   getTrainingForwardOutputDimAndStride(const std::vector<int64_t> &xDim) const {
     // The MEAN and INV_VARIANCE tensors have shape [B, 1, ..., 1]
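
As a worked example (the shape is made up for illustration): for an NCHW input of shape [16, 128, 56, 56], getNormalizedShape() drops the batch dimension and yields [128, 56, 56]. A standalone sketch of the same slicing idiom, not part of the commit:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical NCHW input shape; the real code reads it from the X tensor.
  const std::vector<std::int64_t> xDim = {16, 128, 56, 56};

  // Same idiom as getNormalizedShape(): every dimension except dim 0 (batch).
  std::vector<std::int64_t> normalizedShape(xDim.cbegin() + 1, xDim.cend());

  assert((normalizedShape == std::vector<std::int64_t>{128, 56, 56}));
}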

include/fusilli/support/asm_emitter.h

Lines changed: 195 additions & 0 deletions
@@ -33,6 +33,7 @@
 #include "fusilli/external/torch_types.h"
 #include "fusilli/graph/graph.h"
 #include "fusilli/node/conv_node.h"
+#include "fusilli/node/layernorm_node.h"
 #include "fusilli/node/pointwise_node.h"
 #include "fusilli/support/extras.h"
 
@@ -869,6 +870,200 @@ inline std::string ConvDGradNode::emitNodePreAsm() const {
   return output;
 }
 
+//===----------------------------------------------------------------------===//
+//
+// LayerNormNode ASM Emitter Methods
+//
+//===----------------------------------------------------------------------===//
+
+// Emits LayerNormNode's operand names in MLIR assembly format.
+//
+// The unique suffix is included to ensure SSA uniqueness when the same
+// tensor is used by multiple operations.
+inline std::string LayerNormNode::getOperandNamesAsm() const {
+  std::ostringstream oss;
+  std::string suffix = layernormAttr.getName();
+
+  oss << layernormAttr.getX()->getValueNameAsm() << "_" << suffix << "_perm, ";
+  oss << "%normalized_shape_" << suffix << ", ";
+
+  auto getOptionalOperandNameAsm = [&](const std::shared_ptr<TensorAttr> &t,
+                                       const std::string &name) {
+    return t ? t->getValueNameAsm() + "_" + suffix + "_perm, "
+             : "%none_" + name + "_" + suffix + ", ";
+  };
+
+  oss << getOptionalOperandNameAsm(layernormAttr.getSCALE(), "scale");
+  oss << getOptionalOperandNameAsm(layernormAttr.getBIAS(), "bias");
+  oss << "%eps_" << suffix;
+
+  return oss.str();
+}
+
+// Emits LayerNormNode's operand types in MLIR assembly format.
+inline std::string LayerNormNode::getOperandTypesAsm() const {
+  std::ostringstream oss;
+
+  oss << layernormAttr.getX()->getTensorTypeAsm(/*isValueTensor=*/true,
+                                                /*useLogicalDims=*/true)
+      << ", ";
+  oss << "!torch.list<int>" << ", ";
+
+  auto getOptionalOperandTypeAsm = [&](const std::shared_ptr<TensorAttr> &t) {
+    return t ? t->getTensorTypeAsm(/*isValueTensor=*/true,
+                                   /*useLogicalDims=*/true)
+             : "!torch.none";
+  };
+
+  oss << getOptionalOperandTypeAsm(layernormAttr.getSCALE()) << ", ";
+  oss << getOptionalOperandTypeAsm(layernormAttr.getBIAS()) << ", ";
+  oss << "!torch.float";
+
+  return oss.str();
+}
+
+// Emits LayerNormNode's result names in MLIR assembly format.
+//
+// The unique suffix and "_perm" are included to ensure SSA uniqueness when
+// the same tensor is used by multiple operations. This intermediate result
+// is then used by the output permute.
+inline std::string LayerNormNode::getResultNamesAsm() const {
+  std::ostringstream oss;
+  std::string suffix = layernormAttr.getName();
+
+  oss << layernormAttr.getY()->getValueNameAsm() << "_" << suffix << "_perm";
+
+  if (isTrainingForwardPhase()) {
+    oss << ", ";
+    oss << layernormAttr.getMEAN()->getValueNameAsm() << "_" << suffix
+        << "_perm" << ", ";
+    oss << layernormAttr.getINV_VARIANCE()->getValueNameAsm() << "_" << suffix
+        << "_perm";
+  }
+
+  return oss.str();
+}
+
+// Emits LayerNormNode's result types in MLIR assembly format.
+inline std::string LayerNormNode::getResultTypesAsm() const {
+  std::ostringstream oss;
+  oss << layernormAttr.getY()->getTensorTypeAsm(/*isValueTensor=*/true,
+                                                /*useLogicalDims=*/true);
+
+  if (isTrainingForwardPhase()) {
+    oss << ", ";
+    oss << layernormAttr.getMEAN()->getTensorTypeAsm(/*isValueTensor=*/true,
+                                                     /*useLogicalDims=*/true)
+        << ", ";
+    oss << layernormAttr.getINV_VARIANCE()->getTensorTypeAsm(
+        /*isValueTensor=*/true,
+        /*useLogicalDims=*/true);
+  }
+
+  return oss.str();
+}
+
+// Get normalized_shape list construction ops in MLIR assembly format.
+// normalized_shape is the dimensions to normalize over (typically all dims
+// except batch).
+inline std::string LayerNormNode::getNormalizedShapeOpsAsm() const {
+  return getListOfIntOpsAsm(getNormalizedShape(), /*prefix=*/"normalized_shape",
+                            /*suffix=*/layernormAttr.getName());
+}
+
+// Get epsilon constant op in MLIR assembly format.
+inline std::string LayerNormNode::getEpsilonOpsAsm() const {
+  float eps =
+      std::get<float>(layernormAttr.getEpsilon()->getScalarValue().value());
+  return std::format("%eps_{} = torch.constant.float {:e}",
+                     layernormAttr.getName(), eps);
+}
+
+// This gets called by the recursive `emitAsmSubtree()` method to emit
+// the pre-assembly for each node (including the main Graph). The schema
+// hard-codes things that are not customizable, and leaves the rest
+// for template replacements using `std::format`. When modifying the
+// schema, take extra caution about double bracing the curly brackets
+// (refer to the comments at the top of this file for details).
+inline std::string LayerNormNode::emitNodePreAsm() const {
+  std::string uniqueSSASuffix = layernormAttr.getName();
+  std::string permuteX = getPermuteOpsAsm(layernormAttr.getX(), "permute_x",
+                                          uniqueSSASuffix, /*isInput=*/true);
+  std::string permuteY = getPermuteOpsAsm(layernormAttr.getY(), "permute_y",
+                                          uniqueSSASuffix, /*isInput=*/false);
+  std::string permuteScale =
+      layernormAttr.getSCALE()
+          ? getPermuteOpsAsm(layernormAttr.getSCALE(), "permute_scale",
+                             uniqueSSASuffix, /*isInput=*/true)
+          : std::format("%none_scale_{} = torch.constant.none",
+                        uniqueSSASuffix);
+  std::string permuteBias =
+      layernormAttr.getBIAS()
+          ? getPermuteOpsAsm(layernormAttr.getBIAS(), "permute_bias",
+                             uniqueSSASuffix, /*isInput=*/true)
+          : std::format("%none_bias_{} = torch.constant.none", uniqueSSASuffix);
+
+  if (isTrainingForwardPhase()) {
+    std::string permuteMean =
+        getPermuteOpsAsm(layernormAttr.getMEAN(), "permute_mean",
+                         uniqueSSASuffix, /*isInput=*/false);
+    std::string permuteInvVariance = getPermuteOpsAsm(
+        layernormAttr.getINV_VARIANCE(), "permute_inv_variance",
+        uniqueSSASuffix, /*isInput=*/false);
+
+    constexpr std::string_view schema = R"(
+{0}
+{1}
+{2}
+{3}
+{4}
+{5} = torch.aten.native_layer_norm {6} : {7} -> {8}
+{9}
+{10}
+{11}
+)";
+
+    return std::format(schema,
+                       getNormalizedShapeOpsAsm(), // {0}
+                       getEpsilonOpsAsm(),         // {1}
+                       permuteX,                   // {2}
+                       permuteScale,               // {3}
+                       permuteBias,                // {4}
+                       getResultNamesAsm(),        // {5}
+                       getOperandNamesAsm(),       // {6}
+                       getOperandTypesAsm(),       // {7}
+                       getResultTypesAsm(),        // {8}
+                       permuteY,                   // {9}
+                       permuteMean,                // {10}
+                       permuteInvVariance          // {11}
+    );
+  }
+
+  constexpr std::string_view schema = R"(
+{1}
+{2}
+{3}
+{4}
+{5}
+%cudnn_enable_{0} = torch.constant.bool false
+{6} = torch.aten.layer_norm {7}, %cudnn_enable_{0} : {8}, !torch.bool -> {9}
+{10}
+)";
+
+  return std::format(schema, uniqueSSASuffix,     // {0}
+                     getNormalizedShapeOpsAsm(),  // {1}
+                     getEpsilonOpsAsm(),          // {2}
+                     permuteX,                    // {3}
+                     permuteScale,                // {4}
+                     permuteBias,                 // {5}
+                     getResultNamesAsm(),         // {6}
+                     getOperandNamesAsm(),        // {7}
+                     getOperandTypesAsm(),        // {8}
+                     getResultTypesAsm(),         // {9}
+                     permuteY                     // {10}
+  );
+}
+
 //===----------------------------------------------------------------------===//
 //
 // MatmulNode ASM Emitter Methods
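
Not part of the commit, but to give a rough sense of what the filled-in inference schema looks like: a minimal std::format sketch with a trimmed-down, hypothetical schema (the SSA names, shapes, and types such as ln0 and !torch.vtensor<[16,128],f32> are made up; the real emitter derives them from the graph and its permute helpers). Literal braces in a real schema would need to be doubled ("{{", "}}"), which is what the "double bracing" comment above warns about.

#include <format>
#include <iostream>
#include <string_view>

int main() {
  // Trimmed-down, hypothetical version of the inference-phase schema.
  constexpr std::string_view schema =
      "%eps_{0} = torch.constant.float {1:e}\n"
      "%cudnn_enable_{0} = torch.constant.bool false\n"
      "{2} = torch.aten.layer_norm {3}, %cudnn_enable_{0} : {4}, !torch.bool -> {5}\n";

  std::cout << std::format(
      schema,
      "ln0",          // {0} unique SSA suffix
      1e-5f,          // {1} epsilon, inlined as a constant rather than bound at runtime
      "%y_ln0_perm",  // {2} result name
      "%x_ln0_perm, %normalized_shape_ln0, "
      "%scale_ln0_perm, %bias_ln0_perm, %eps_ln0",  // {3} operand names
      "!torch.vtensor<[16,128],f32>, !torch.list<int>, "
      "!torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>, "
      "!torch.float",                               // {4} operand types
      "!torch.vtensor<[16,128],f32>");              // {5} result type
}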

samples/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -54,3 +54,18 @@ add_fusilli_samples(
     libutils
     Catch2::Catch2WithMain
 )
+
+add_fusilli_samples(
+  PREFIX fusilli_layernorm_samples
+  SRCS
+    layernorm/layernorm_infer_nchw.cpp
+    layernorm/layernorm_infer_nchw_scale_bias.cpp
+    layernorm/layernorm_infer_nhwc_scale_bias.cpp
+    layernorm/layernorm_train_nchw.cpp
+    layernorm/layernorm_train_nchw_scale_bias.cpp
+    layernorm/layernorm_train_nhwc_scale_bias.cpp
+  DEPS
+    libfusilli
+    libutils
+    Catch2::Catch2WithMain
+)
