Skip to content

Commit 363b7d5

Browse files
authored
Support RMSNorm (#136)
This PR adds support for RMSNorm (Root Mean Square Normalization) operation to the Deeploy framework's Generic platform. RMSNorm is a critical normalization technique used in modern Transformer architectures and large language models. To enable RMSNorm deployment on embedded systems, this PR implements the necessary mathematical primitives (Pow and Sqrt operations) and integrates them into Deeploy's compilation pipeline. The implementation follows Deeploy's operator decomposition approach, where RMSNorm is constructed from basic mathematical operations rather than as a monolithic kernel. This design provides flexibility and maintainability while supporting both float32 and float16 precision for resource-constrained embedded devices. ## Added - **Pow (Power) operation support** - `FloatPowTemplate.py`: Mako template for C code generation - `Pow_fp32.c` Kernel implementations for both precisions - `kernel/Pow.h`: Kernel interface definitions - Parser, Layer, and Binding classes for framework integration - **Sqrt (Square Root) operation support** - `FloatSqrtTemplate.py`: Mako template for C code generation - `Sqrt_fp32.c` : Kernel implementations - `kernel/Sqrt.h`: Kernel interface definitions - Complete framework integration components - **Comprehensive test suites** - `testFloatPow` : Pow operator tests with ONNX models and reference data - `testFloatSqrt` : Sqrt operator tests - `testFloatRMSNorm`: End-to-end RMSNorm tests demonstrating operator composition ## Changed - **Framework integration files** - `Deeploy/Targets/Generic/Parsers.py`: Added PowParser and SqrtParser for ONNX graph parsing - `Deeploy/Targets/Generic/Layers.py`: Added corresponding Layer classes for both operations - `Deeploy/Targets/Generic/Bindings.py`: Added type checking and binding registration - `Deeploy/Targets/Generic/Platform.py`: Registered new operations in platform mapping - **Runtime library headers** - `TargetLibraries/Generic/inc/DeeployBasicMath.h`: Extended with Pow and Sqrt function declarations - `TargetLibraries/Generic/inc/types.h`: Updated type definitions for consistency - **CI/CD configuration** - `.github/workflows/ci-platform-generic.yml`: Updated to include new test cases in automated testing pipeline
1 parent e07cd13 commit 363b7d5

File tree

26 files changed

+302
-15
lines changed

26 files changed

+302
-15
lines changed

.github/workflows/ci-platform-generic.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ jobs:
7373
testFloatSoftmax
7474
testFloatTranspose
7575
testFloatMul
76+
testFloatPowScalar
77+
testFloatPowVector
78+
testFloatSqrt
79+
testFloatRMSNorm
7680
Quant
7781
Dequant
7882
QuantizedLinear

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ package-lock.json
2424
.mypy_cache
2525
node_modules
2626

27+
.venv/*
28+
2729
compile_commands.json
2830

2931
docs/_autosummary

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
44
## Unreleased (Planned Release Target: v0.2.1)
55

66
### List of Pull Requests
7+
- Support for RMSNorm (Pow and Sqrt operators) [#136](https://github.com/pulp-platform/Deeploy/pull/136)
78
- Demo TinyViT compatibility with tiled Siracusa [#124](https://github.com/pulp-platform/Deeploy/pull/124)
89
- TinyViT on non-tiled Siracusa [#117](https://github.com/pulp-platform/Deeploy/pull/117)
910
- Support Fully Asynchronous DMAs [#114](https://github.com/pulp-platform/Deeploy/pull/114)
@@ -26,6 +27,8 @@ This file contains the changelog for the Deeploy project. The changelog is divid
2627
- Fix bias hoisting in generic GEMM with no bias [#126](https://github.com/pulp-platform/Deeploy/pull/126)
2728

2829
### Added
30+
- Support for RMSNorm operation via operator decomposition.
31+
- Added `Pow` (Power) and `Sqrt` (Square Root) operation support (Parsers, Layers, Bindings, Templates, and FP32 Kernels) for the Generic platform.
2932
- Support for input tiling for PULP FP regular and DW conv 2D.
3033
- CI tests for tiled Siracusa FP regular and DW conv 2D, with and without bias, for skip connections, and for the demo version of TinyViT.
3134
- Documentation for PULP FP regular and DW conv 2D and MatMul tile constraints.

Deeploy/Targets/Generic/Bindings.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
1616
FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
1717
FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
18-
FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \
19-
IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \
20-
PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \
21-
RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \
22-
iRMSNormTemplate, iSoftmaxTemplate
18+
FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, \
19+
GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \
20+
MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
21+
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
22+
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
2323
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
2424
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
2525
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
@@ -118,6 +118,16 @@
118118
BasicTransformer)
119119
]
120120

121+
BasicPowBindings = [
122+
NodeBinding(DummyChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
123+
FloatPowTemplate.referenceTemplate, BasicTransformer),
124+
]
125+
126+
BasicSqrtBindings = [
127+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatSqrtTemplate.referenceTemplate,
128+
BasicTransformer),
129+
]
130+
121131
BasicDivBindings = [
122132
NodeBinding(DivChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]),
123133
IntegerDivTemplate.referenceTemplate, BasicTransformer)

Deeploy/Targets/Generic/Layers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ def computeOps(self):
227227
return matmul + rqs
228228

229229

230+
class PowLayer(ONNXLayer):
231+
232+
def __init__(self, maps: List[NodeMapper]):
233+
super().__init__(maps)
234+
235+
236+
class SqrtLayer(ONNXLayer):
237+
238+
def __init__(self, maps: List[NodeMapper]):
239+
super().__init__(maps)
240+
241+
230242
class DivLayer(ONNXLayer):
231243

232244
def __init__(self, maps: List[NodeMapper]):

Deeploy/Targets/Generic/Parsers.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import onnx_graphsurgeon as gs
1010

11-
from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer
11+
from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeParser, VariableBuffer
1212

1313

1414
class ConcatParser(NodeParser):
@@ -1964,6 +1964,32 @@ def parseNodeCtxt(self,
19641964
return ctxt, True
19651965

19661966

1967+
class PowParser(NodeParser):
1968+
1969+
def __init__(self):
1970+
super().__init__()
1971+
1972+
def parseNode(self, node: gs.Node) -> bool:
1973+
return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1
1974+
1975+
def parseNodeCtxt(self,
1976+
ctxt: NetworkContext,
1977+
node: gs.Node,
1978+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
1979+
1980+
# Lookup both inputs (data and exponent)
1981+
data_in = ctxt.lookup(node.inputs[0].name)
1982+
exponent_tensor = ctxt.lookup(node.inputs[1].name)
1983+
data_out = ctxt.lookup(node.outputs[0].name)
1984+
1985+
self.operatorRepresentation['data_in'] = data_in.name
1986+
self.operatorRepresentation['exponent'] = exponent_tensor.name
1987+
self.operatorRepresentation['data_out'] = data_out.name
1988+
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
1989+
1990+
return ctxt, True
1991+
1992+
19671993
class DivParser(NodeParser):
19681994

19691995
def __init__(self):
@@ -2747,3 +2773,26 @@ def parseNodeCtxt(self,
27472773
"ch_im_out"] * self.operatorRepresentation["dim_im_out_y"]
27482774
return newCtxt, True
27492775
return ctxt, False
2776+
2777+
2778+
class SqrtParser(NodeParser):
2779+
2780+
def __init__(self):
2781+
super().__init__()
2782+
2783+
def parseNode(self, node: gs.Node) -> bool:
2784+
return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1
2785+
2786+
def parseNodeCtxt(self,
2787+
ctxt: NetworkContext,
2788+
node: gs.Node,
2789+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
2790+
2791+
data_in = ctxt.lookup(node.inputs[0].name)
2792+
data_out = ctxt.lookup(node.outputs[0].name)
2793+
2794+
self.operatorRepresentation['data_in'] = data_in.name
2795+
self.operatorRepresentation['data_out'] = data_out.name
2796+
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2797+
2798+
return ctxt, True

Deeploy/Targets/Generic/Platform.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,22 @@
1111
BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \
1212
BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \
1313
BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \
14-
BasicPad1DBindings, BasicPad2DBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, \
15-
BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, \
16-
BasicSliceBindings, BasicSoftmaxBindings, BasicTransposeBindings, DummyBinding
14+
BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \
15+
BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \
16+
BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \
17+
DummyBinding
1718
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \
1819
ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \
19-
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \
20-
ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, SoftmaxLayer, \
21-
TransposeLayer
20+
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
21+
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \
22+
SoftmaxLayer, SqrtLayer, TransposeLayer
2223
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \
2324
DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \
2425
GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \
2526
IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \
26-
Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \
27-
ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, \
28-
iLayerNormParser, iSoftmaxParser
27+
Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \
28+
RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, SqrtParser, \
29+
TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
2930
from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
3031
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
3132
ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
@@ -52,6 +53,8 @@
5253
MaxPoolMapper = NodeMapper(GenericMaxPool2DParser(), BasicMaxPool2DBindings)
5354
MaxPool1DMapper = NodeMapper(MaxPool1DParser(), BasicMaxPool1DBindings)
5455
MulMapper = NodeMapper(MulParser(), BasicMulBindings)
56+
PowMapper = NodeMapper(PowParser(), BasicPowBindings)
57+
SqrtMapper = NodeMapper(SqrtParser(), BasicSqrtBindings)
5558
Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings)
5659
Pad2DMapper = NodeMapper(Pad2DParser(), BasicPad2DBindings)
5760
ReduceMeanMapper = NodeMapper(ReduceMeanParser(), BasicReduceMeanBindings)
@@ -98,6 +101,8 @@
98101
'MatMulInteger': MatMulLayer([MatMulMapper]),
99102
'MaxPool': MaxPoolLayer([MaxPool1DMapper, MaxPoolMapper]),
100103
'Mul': MulLayer([MulMapper]),
104+
'Pow': PowLayer([PowMapper]),
105+
'Sqrt': SqrtLayer([SqrtMapper]),
101106
'Pad': PadLayer([Pad1DMapper, Pad2DMapper]),
102107
'ReduceMean': ReduceMeanLayer([ReduceMeanMapper]),
103108
'ReduceSum': ReduceSumLayer([ReduceSumMapper]),
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from typing import Dict, List, Tuple
5+
6+
import numpy as np
7+
8+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
9+
10+
11+
class _PowTemplate(NodeTemplate):
12+
13+
def alignToContext(self, ctxt: NetworkContext,
14+
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
15+
# Get input and output tensors
16+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
17+
exponent = ctxt.lookup(operatorRepresentation['exponent'])
18+
data_out = ctxt.lookup(operatorRepresentation['data_out'])
19+
20+
# Get data type (fp32)
21+
data_type = data_in._type.typeName
22+
operatorRepresentation['data_type'] = data_type
23+
24+
# Get type width dynamically (e.g., 32, 64)
25+
type_width = data_in._type.referencedType.typeWidth
26+
operatorRepresentation['type_width'] = type_width
27+
28+
# Calculate size
29+
input_size = int(np.prod(data_in.shape))
30+
exponent_size = int(np.prod(exponent.shape))
31+
operatorRepresentation['size'] = input_size
32+
33+
# Check if exponent is scalar (broadcasting)
34+
if exponent_size == 1:
35+
operatorRepresentation['is_scalar'] = True
36+
# Get the full variable name with prefix
37+
exponent_name = operatorRepresentation['exponent']
38+
operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]"
39+
else:
40+
# Since currently the kernel only supports equally sized base-exponent data,
41+
# for non-scalar, let's add a size check here (length of data_in should be equal to exponent length).
42+
if input_size != exponent_size:
43+
raise ValueError(f"Pow operator mismatch: input size ({input_size}) "
44+
f"must equal exponent size ({exponent_size}) for non-scalar exponents.")
45+
46+
operatorRepresentation['is_scalar'] = False
47+
operatorRepresentation['exponent_scalar'] = "NULL"
48+
49+
return ctxt, operatorRepresentation, []
50+
51+
52+
referenceTemplate = _PowTemplate("""
53+
// Pow (Name: ${nodeName}, Op: ${nodeOp})
54+
% if is_scalar:
55+
Pow_fp${type_width}_scalar_fp${type_width}(${data_in}, ${exponent_scalar}, ${data_out}, ${size});
56+
% else:
57+
Pow_fp${type_width}_fp${type_width}_fp${type_width}(${data_in}, ${exponent}, ${data_out}, ${size});
58+
% endif
59+
""")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from typing import Dict, List, Tuple
5+
6+
import numpy as np
7+
8+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
9+
10+
11+
class _SqrtTemplate(NodeTemplate):
12+
13+
def alignToContext(self, ctxt: NetworkContext,
14+
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
15+
# Get input and output tensors
16+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
17+
data_out = ctxt.lookup(operatorRepresentation['data_out'])
18+
19+
# Get data type (fp32)
20+
data_type = data_in._type.typeName
21+
operatorRepresentation['data_type'] = data_type
22+
23+
type_width = data_in._type.referencedType.typeWidth
24+
operatorRepresentation['type_width'] = type_width
25+
26+
# Calculate size
27+
operatorRepresentation['size'] = int(np.prod(data_in.shape))
28+
29+
return ctxt, operatorRepresentation, []
30+
31+
32+
referenceTemplate = _SqrtTemplate("""
33+
// Sqrt (Name: ${nodeName}, Op: ${nodeOp})
34+
Sqrt_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size});
35+
""")
4.26 KB
Binary file not shown.

0 commit comments

Comments
 (0)