|
14 | 14 | #ifndef FUSILLI_GRAPH_GRAPH_H |
15 | 15 | #define FUSILLI_GRAPH_GRAPH_H |
16 | 16 |
|
| 17 | +#include "fusilli/attributes/batchnorm_attributes.h" |
17 | 18 | #include "fusilli/attributes/common.h" |
18 | 19 | #include "fusilli/attributes/conv_attributes.h" |
19 | 20 | #include "fusilli/attributes/custom_op_attributes.h" |
|
30 | 31 | #include "fusilli/backend/compile_session.h" |
31 | 32 | #include "fusilli/backend/handle.h" |
32 | 33 | #include "fusilli/graph/context.h" |
| 34 | +#include "fusilli/node/batchnorm_node.h" |
33 | 35 | #include "fusilli/node/conv_node.h" |
34 | 36 | #include "fusilli/node/custom_op_node.h" |
35 | 37 | #include "fusilli/node/layernorm_node.h" |
@@ -268,6 +270,13 @@ class Graph : public INode { |
268 | 270 | const std::shared_ptr<TensorAttr> &w, |
269 | 271 | ConvDGradAttr &attributes); |
270 | 272 | std::array<std::shared_ptr<TensorAttr>, 3> |
| 273 | + batchnorm(const std::shared_ptr<TensorAttr> &x, |
| 274 | + const std::shared_ptr<TensorAttr> &scale, |
| 275 | + const std::shared_ptr<TensorAttr> &bias, |
| 276 | + const std::shared_ptr<TensorAttr> &mean, |
| 277 | + const std::shared_ptr<TensorAttr> &var, BatchnormAttr &attributes); |
| 278 | + |
| 279 | + std::array<std::shared_ptr<TensorAttr>, 3> |
271 | 280 | layernorm(const std::shared_ptr<TensorAttr> &x, |
272 | 281 | const std::shared_ptr<TensorAttr> &scale, |
273 | 282 | const std::shared_ptr<TensorAttr> &bias, LayernormAttr &attributes); |
@@ -727,6 +736,60 @@ Graph::convDGrad(const std::shared_ptr<TensorAttr> &dy, |
727 | 736 | return dx; |
728 | 737 | } |
729 | 738 |
|
| 739 | +// Create a BatchNormNode, populate it with the specified attributes, create |
| 740 | +// output tensors and add the node to the graph's sub nodes. |
| 741 | +inline std::array<std::shared_ptr<TensorAttr>, 3> |
| 742 | +Graph::batchnorm(const std::shared_ptr<TensorAttr> &x, |
| 743 | + const std::shared_ptr<TensorAttr> &scale, |
| 744 | + const std::shared_ptr<TensorAttr> &bias, |
| 745 | + const std::shared_ptr<TensorAttr> &mean, |
| 746 | + const std::shared_ptr<TensorAttr> &var, |
| 747 | + BatchnormAttr &batchnormAttr) { |
| 748 | + // Populate names when not set. |
| 749 | + if (batchnormAttr.getName().empty()) |
| 750 | + batchnormAttr.setName("batchnorm_" + std::to_string(subNodes_.size())); |
| 751 | + if (x && x->getName().empty()) |
| 752 | + x->setName(batchnormAttr.getName() + "_X"); |
| 753 | + if (scale && scale->getName().empty()) |
| 754 | + scale->setName(batchnormAttr.getName() + "_SCALE"); |
| 755 | + if (bias && bias->getName().empty()) |
| 756 | + bias->setName(batchnormAttr.getName() + "_BIAS"); |
| 757 | + if (mean && mean->getName().empty()) |
| 758 | + mean->setName(batchnormAttr.getName() + "_MEAN"); |
| 759 | + if (var && var->getName().empty()) |
| 760 | + var->setName(batchnormAttr.getName() + "_VAR"); |
| 761 | + auto eps = batchnormAttr.getEpsilon(); |
| 762 | + if (eps && eps->getName().empty()) |
| 763 | + eps->setName(batchnormAttr.getName() + "_EPSILON"); |
| 764 | + auto mom = batchnormAttr.getMomentum(); |
| 765 | + if (mom && mom->getName().empty()) |
| 766 | + mom->setName(batchnormAttr.getName() + "_MOMENTUM"); |
| 767 | + |
| 768 | + FUSILLI_LOG_LABEL_ENDL("INFO: Adding BatchNorm '" << batchnormAttr.getName() |
| 769 | + << "' to Graph"); |
| 770 | + |
| 771 | + // Set inputs. |
| 772 | + batchnormAttr.setX(x).setSCALE(scale).setBIAS(bias).setMEAN(mean).setVAR(var); |
| 773 | + |
| 774 | + // Set outputs. |
| 775 | + std::shared_ptr<TensorAttr> y = outputTensor(batchnormAttr.getName() + "_Y"); |
| 776 | + std::shared_ptr<TensorAttr> savedMean = nullptr; |
| 777 | + std::shared_ptr<TensorAttr> savedInvVar = nullptr; |
| 778 | + if (batchnormAttr.getForwardPhase() == NormFwdPhase::TRAINING) { |
| 779 | + savedMean = outputTensor(batchnormAttr.getName() + "_SAVED_MEAN"); |
| 780 | + savedInvVar = outputTensor(batchnormAttr.getName() + "_SAVED_INV_VARIANCE"); |
| 781 | + } |
| 782 | + batchnormAttr.setY(y); |
| 783 | + batchnormAttr.setSAVED_MEAN(savedMean); |
| 784 | + batchnormAttr.setSAVED_INV_VARIANCE(savedInvVar); |
| 785 | + |
| 786 | + // Create node and add to Graph's subNodes_. |
| 787 | + subNodes_.emplace_back( |
| 788 | + std::make_unique<BatchNormNode>(std::move(batchnormAttr), context)); |
| 789 | + |
| 790 | + return {std::move(y), std::move(savedMean), std::move(savedInvVar)}; |
| 791 | +} |
| 792 | + |
730 | 793 | // Create a LayerNormNode, populate it with the specified attributes, create |
731 | 794 | // output tensors and add the node to the graph's sub nodes |
732 | 795 | inline std::array<std::shared_ptr<TensorAttr>, 3> |
|
0 commit comments