|
33 | 33 | #include "fusilli/external/torch_types.h" |
34 | 34 | #include "fusilli/graph/graph.h" |
35 | 35 | #include "fusilli/node/conv_node.h" |
| 36 | +#include "fusilli/node/layernorm_node.h" |
36 | 37 | #include "fusilli/node/pointwise_node.h" |
37 | 38 | #include "fusilli/support/extras.h" |
38 | 39 |
|
@@ -869,6 +870,200 @@ inline std::string ConvDGradNode::emitNodePreAsm() const { |
869 | 870 | return output; |
870 | 871 | } |
871 | 872 |
|
| 873 | +//===----------------------------------------------------------------------===// |
| 874 | +// |
| 875 | +// LayerNormNode ASM Emitter Methods |
| 876 | +// |
| 877 | +//===----------------------------------------------------------------------===// |
| 878 | + |
| 879 | +// Emits LayerNormNode's operand names in MLIR assembly format. |
| 880 | +// |
| 881 | +// The unique suffix is included to ensure SSA uniqueness when the same |
| 882 | +// tensor is used by multiple operations. |
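| | +// |
| | +// Illustrative example (hypothetical names): for a node named "ln0" whose X |
| | +// and SCALE value names are "%x" and "%scale", with no BIAS set, this would |
| | +// produce roughly: |
| | +//   %x_ln0_perm, %normalized_shape_ln0, %scale_ln0_perm, %none_bias_ln0, %eps_ln0 |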
| 883 | +inline std::string LayerNormNode::getOperandNamesAsm() const { |
| 884 | + std::ostringstream oss; |
| 885 | + std::string suffix = layernormAttr.getName(); |
| 886 | + |
| 887 | + oss << layernormAttr.getX()->getValueNameAsm() << "_" << suffix << "_perm, "; |
| 888 | + oss << "%normalized_shape_" << suffix << ", "; |
| 889 | + |
| 890 | + auto getOptionalOperandNameAsm = [&](const std::shared_ptr<TensorAttr> &t, |
| 891 | + const std::string &name) { |
| 892 | + return t ? t->getValueNameAsm() + "_" + suffix + "_perm, " |
| 893 | + : "%none_" + name + "_" + suffix + ", "; |
| 894 | + }; |
| 895 | + |
| 896 | + oss << getOptionalOperandNameAsm(layernormAttr.getSCALE(), "scale"); |
| 897 | + oss << getOptionalOperandNameAsm(layernormAttr.getBIAS(), "bias"); |
| 898 | + oss << "%eps_" << suffix; |
| 899 | + |
| 900 | + return oss.str(); |
| 901 | +} |
| 902 | + |
| 903 | +// Emits LayerNormNode's operand types in MLIR assembly format. |
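| | +// |
| | +// Illustrative example (hypothetical shapes, assuming the usual torch-mlir |
| | +// value-tensor spelling): for an f32 input of logical shape [2, 4, 8] with |
| | +// neither SCALE nor BIAS set, this would produce something like: |
| | +//   !torch.vtensor<[2,4,8],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.float |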
| 904 | +inline std::string LayerNormNode::getOperandTypesAsm() const { |
| 905 | + std::ostringstream oss; |
| 906 | + |
| 907 | + oss << layernormAttr.getX()->getTensorTypeAsm(/*isValueTensor=*/true, |
| 908 | + /*useLogicalDims=*/true) |
| 909 | + << ", "; |
| 910 | + oss << "!torch.list<int>" << ", "; |
| 911 | + |
| 912 | + auto getOptionalOperandTypeAsm = [&](const std::shared_ptr<TensorAttr> &t) { |
| 913 | + return t ? t->getTensorTypeAsm(/*isValueTensor=*/true, |
| 914 | + /*useLogicalDims=*/true) |
| 915 | + : "!torch.none"; |
| 916 | + }; |
| 917 | + |
| 918 | + oss << getOptionalOperandTypeAsm(layernormAttr.getSCALE()) << ", "; |
| 919 | + oss << getOptionalOperandTypeAsm(layernormAttr.getBIAS()) << ", "; |
| 920 | + oss << "!torch.float"; |
| 921 | + |
| 922 | + return oss.str(); |
| 923 | +} |
| 924 | + |
| 925 | +// Emits LayerNormNode's result names in MLIR assembly format. |
| 926 | +// |
| 927 | +// The unique suffix and "_perm" are included to ensure SSA uniqueness when |
| 928 | +// the same tensor is used by multiple operations. This intermediate result |
| 929 | +// is then used by the output permute. |
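| | +// |
| | +// Illustrative example (hypothetical names): for a node named "ln0" in the |
| | +// training forward phase, with Y, MEAN, and INV_VARIANCE value names "%y", |
| | +// "%mean", and "%inv_var", this would produce: |
| | +//   %y_ln0_perm, %mean_ln0_perm, %inv_var_ln0_perm |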
| 930 | +inline std::string LayerNormNode::getResultNamesAsm() const { |
| 931 | + std::ostringstream oss; |
| 932 | + std::string suffix = layernormAttr.getName(); |
| 933 | + |
| 934 | + oss << layernormAttr.getY()->getValueNameAsm() << "_" << suffix << "_perm"; |
| 935 | + |
| 936 | + if (isTrainingForwardPhase()) { |
| 937 | + oss << ", "; |
| 938 | + oss << layernormAttr.getMEAN()->getValueNameAsm() << "_" << suffix |
| 939 | + << "_perm" << ", "; |
| 940 | + oss << layernormAttr.getINV_VARIANCE()->getValueNameAsm() << "_" << suffix |
| 941 | + << "_perm"; |
| 942 | + } |
| 943 | + |
| 944 | + return oss.str(); |
| 945 | +} |
| 946 | + |
| 947 | +// Emits LayerNormNode's result types in MLIR assembly format. |
| 948 | +inline std::string LayerNormNode::getResultTypesAsm() const { |
| 949 | + std::ostringstream oss; |
| 950 | + oss << layernormAttr.getY()->getTensorTypeAsm(/*isValueTensor=*/true, |
| 951 | + /*useLogicalDims=*/true); |
| 952 | + |
| 953 | + if (isTrainingForwardPhase()) { |
| 954 | + oss << ", "; |
| 955 | + oss << layernormAttr.getMEAN()->getTensorTypeAsm(/*isValueTensor=*/true, |
| 956 | + /*useLogicalDims=*/true) |
| 957 | + << ", "; |
| 958 | + oss << layernormAttr.getINV_VARIANCE()->getTensorTypeAsm( |
| 959 | + /*isValueTensor=*/true, |
| 960 | + /*useLogicalDims=*/true); |
| 961 | + } |
| 962 | + |
| 963 | + return oss.str(); |
| 964 | +} |
| 965 | + |
| 966 | +// Emits the normalized_shape list construction ops in MLIR assembly format. |
| 967 | +// normalized_shape specifies the dimensions to normalize over (typically all |
| 968 | +// dims except batch). |
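| | +// |
| | +// The emitted ops follow the usual torch-mlir pattern for building an |
| | +// integer list. For example, a normalized_shape of [8] on a node named "ln0" |
| | +// would look roughly like (SSA names here are illustrative): |
| | +//   %int8_ln0 = torch.constant.int 8 |
| | +//   %normalized_shape_ln0 = torch.prim.ListConstruct %int8_ln0 : (!torch.int) -> !torch.list<int> |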
| 969 | +inline std::string LayerNormNode::getNormalizedShapeOpsAsm() const { |
| 970 | + return getListOfIntOpsAsm(getNormalizedShape(), /*prefix=*/"normalized_shape", |
| 971 | + /*suffix=*/layernormAttr.getName()); |
| 972 | +} |
| 973 | + |
| 974 | +// Emits the epsilon constant op in MLIR assembly format. |
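| | +// |
| | +// For example, a node named "ln0" (hypothetical) with eps = 1e-5 emits: |
| | +//   %eps_ln0 = torch.constant.float 1.000000e-05 |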
| 975 | +inline std::string LayerNormNode::getEpsilonOpsAsm() const { |
| 976 | + float eps = |
| 977 | + std::get<float>(layernormAttr.getEpsilon()->getScalarValue().value()); |
| 978 | + return std::format("%eps_{} = torch.constant.float {:e}", |
| 979 | + layernormAttr.getName(), eps); |
| 980 | +} |
| 981 | + |
| 982 | +// This gets called by the recursive `emitAsmSubtree()` method to emit |
| 983 | +// the pre-assembly for each node (including the main Graph). The schema |
| 984 | +// hard-codes things that are not customizable, and leaves the rest |
| 985 | +// for template replacements using `std::format`. When modifying the |
| 986 | +// schema, take extra caution about double bracing the curly brackets |
| 987 | +// (refer to the comments at the top of this file for details). |
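| | +// |
| | +// Sketch of the ops the two schemas below expand to (operand/result names |
| | +// and types are filled in by the helpers above; SSA names abbreviated): |
| | +//   training:  %y_.., %mean_.., %inv_var_.. = torch.aten.native_layer_norm %x_.., %shape_.., %scale_.., %bias_.., %eps_.. : ... -> ..., ..., ... |
| | +//   inference: %y_.. = torch.aten.layer_norm %x_.., %shape_.., %scale_.., %bias_.., %eps_.., %cudnn_enable_.. : ..., !torch.bool -> ... |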
| 988 | +inline std::string LayerNormNode::emitNodePreAsm() const { |
| 989 | + std::string uniqueSSASuffix = layernormAttr.getName(); |
| 990 | + std::string permuteX = getPermuteOpsAsm(layernormAttr.getX(), "permute_x", |
| 991 | + uniqueSSASuffix, /*isInput=*/true); |
| 992 | + std::string permuteY = getPermuteOpsAsm(layernormAttr.getY(), "permute_y", |
| 993 | + uniqueSSASuffix, /*isInput=*/false); |
| 994 | + std::string permuteScale = |
| 995 | + layernormAttr.getSCALE() |
| 996 | + ? getPermuteOpsAsm(layernormAttr.getSCALE(), "permute_scale", |
| 997 | + uniqueSSASuffix, /*isInput=*/true) |
| 998 | + : std::format("%none_scale_{} = torch.constant.none", |
| 999 | + uniqueSSASuffix); |
| 1000 | + std::string permuteBias = |
| 1001 | + layernormAttr.getBIAS() |
| 1002 | + ? getPermuteOpsAsm(layernormAttr.getBIAS(), "permute_bias", |
| 1003 | + uniqueSSASuffix, /*isInput=*/true) |
| 1004 | + : std::format("%none_bias_{} = torch.constant.none", uniqueSSASuffix); |
| 1005 | + |
| 1006 | + if (isTrainingForwardPhase()) { |
| 1007 | + std::string permuteMean = |
| 1008 | + getPermuteOpsAsm(layernormAttr.getMEAN(), "permute_mean", |
| 1009 | + uniqueSSASuffix, /*isInput=*/false); |
| 1010 | + std::string permuteInvVariance = getPermuteOpsAsm( |
| 1011 | + layernormAttr.getINV_VARIANCE(), "permute_inv_variance", |
| 1012 | + uniqueSSASuffix, /*isInput=*/false); |
| 1013 | + |
| 1014 | + constexpr std::string_view schema = R"( |
| 1015 | + {0} |
| 1016 | + {1} |
| 1017 | + {2} |
| 1018 | + {3} |
| 1019 | + {4} |
| 1020 | + {5} = torch.aten.native_layer_norm {6} : {7} -> {8} |
| 1021 | + {9} |
| 1022 | + {10} |
| 1023 | + {11} |
| 1024 | + )"; |
| 1025 | + |
| 1026 | + return std::format(schema, |
| 1027 | + getNormalizedShapeOpsAsm(), // {0} |
| 1028 | + getEpsilonOpsAsm(), // {1} |
| 1029 | + permuteX, // {2} |
| 1030 | + permuteScale, // {3} |
| 1031 | + permuteBias, // {4} |
| 1032 | + getResultNamesAsm(), // {5} |
| 1033 | + getOperandNamesAsm(), // {6} |
| 1034 | + getOperandTypesAsm(), // {7} |
| 1035 | + getResultTypesAsm(), // {8} |
| 1036 | + permuteY, // {9} |
| 1037 | + permuteMean, // {10} |
| 1038 | + permuteInvVariance // {11} |
| 1039 | + ); |
| 1040 | + } |
| 1041 | + |
| 1042 | + constexpr std::string_view schema = R"( |
| 1043 | + {1} |
| 1044 | + {2} |
| 1045 | + {3} |
| 1046 | + {4} |
| 1047 | + {5} |
| 1048 | + %cudnn_enable_{0} = torch.constant.bool false |
| 1049 | + {6} = torch.aten.layer_norm {7}, %cudnn_enable_{0} : {8}, !torch.bool -> {9} |
| 1050 | + {10} |
| 1051 | + )"; |
| 1052 | + |
| 1053 | + return std::format(schema, uniqueSSASuffix, // {0} |
| 1054 | + getNormalizedShapeOpsAsm(), // {1} |
| 1055 | + getEpsilonOpsAsm(), // {2} |
| 1056 | + permuteX, // {3} |
| 1057 | + permuteScale, // {4} |
| 1058 | + permuteBias, // {5} |
| 1059 | + getResultNamesAsm(), // {6} |
| 1060 | + getOperandNamesAsm(), // {7} |
| 1061 | + getOperandTypesAsm(), // {8} |
| 1062 | + getResultTypesAsm(), // {9} |
| 1063 | + permuteY // {10} |
| 1064 | + ); |
| 1065 | +} |
| 1066 | + |
872 | 1067 | //===----------------------------------------------------------------------===// |
873 | 1068 | // |
874 | 1069 | // MatmulNode ASM Emitter Methods |
|