Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-platform-generic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ jobs:
testFloatSoftmax
testFloatTranspose
testFloatMul
testFloatPowScalar
testFloatPowVector
testFloatSqrt
testFloatRMSNorm
Quant
Dequant
QuantizedLinear
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package-lock.json
.mypy_cache
node_modules

.venv/*

compile_commands.json

docs/_autosummary
Expand Down
20 changes: 15 additions & 5 deletions Deeploy/Targets/Generic/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \
IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \
PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \
RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \
iRMSNormTemplate, iSoftmaxTemplate
FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, \
GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \
MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
Expand Down Expand Up @@ -118,6 +118,16 @@
BasicTransformer)
]

# Bindings for the float Pow template: two float32 pointer inputs (base and exponent)
# and one float32 pointer output.
# NOTE(review): this uses DummyChecker rather than a Pow-specific type checker — confirm this is intended.
BasicPowBindings = [
    NodeBinding(DummyChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                FloatPowTemplate.referenceTemplate, BasicTransformer),
]

# Bindings for the float Sqrt template: one float32 pointer input and one float32 pointer output.
# NOTE(review): this uses DummyChecker rather than a Sqrt-specific type checker — confirm this is intended.
BasicSqrtBindings = [
    NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatSqrtTemplate.referenceTemplate,
                BasicTransformer),
]

BasicDivBindings = [
NodeBinding(DivChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]),
IntegerDivTemplate.referenceTemplate, BasicTransformer)
Expand Down
12 changes: 12 additions & 0 deletions Deeploy/Targets/Generic/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ def computeOps(self):
return matmul + rqs


class PowLayer(ONNXLayer):
    """Layer wrapper for Pow nodes; adds nothing on top of ``ONNXLayer``."""

    def __init__(self, maps: List[NodeMapper]):
        """Forward the candidate node mappers unchanged to ``ONNXLayer``."""
        super().__init__(maps)


class SqrtLayer(ONNXLayer):
    """Layer wrapper for Sqrt nodes; adds nothing on top of ``ONNXLayer``."""

    def __init__(self, maps: List[NodeMapper]):
        """Forward the candidate node mappers unchanged to ``ONNXLayer``."""
        super().__init__(maps)


class DivLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
Expand Down
51 changes: 50 additions & 1 deletion Deeploy/Targets/Generic/Parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import onnx_graphsurgeon as gs

from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer
from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeParser, VariableBuffer


class ConcatParser(NodeParser):
Expand Down Expand Up @@ -1964,6 +1964,32 @@ def parseNodeCtxt(self,
return ctxt, True


class PowParser(NodeParser):
    """Parser for ``Pow`` nodes (element-wise power of a base tensor by an exponent tensor).

    Only two exponent layouts are accepted: a single-element exponent, or an
    exponent with the same number of elements as the base input. Any other
    (broadcast) layout is rejected at parse time.
    """

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        """Return True for ``Pow`` nodes with exactly two inputs and one output."""
        return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
        """Bind the node's buffers into ``operatorRepresentation``.

        Args:
            ctxt: Network context used to resolve the node's tensors by name.
            node: The ``Pow`` node; input 0 is the base, input 1 the exponent.
            channels_first: Unused by this parser.

        Returns:
            ``(ctxt, True)`` when the node was bound, ``(ctxt, False)`` when the
            exponent layout is not supported (``operatorRepresentation`` is left
            untouched in that case).
        """
        # Lookup both inputs (data and exponent) and the output
        data_in = ctxt.lookup(node.inputs[0].name)
        exponent_tensor = ctxt.lookup(node.inputs[1].name)
        data_out = ctxt.lookup(node.outputs[0].name)

        input_size = int(np.prod(data_in.shape))
        exponent_size = int(np.prod(exponent_tensor.shape))

        # Reject general broadcasting here instead of failing later during code
        # generation: only a one-element exponent or an exponent matching the
        # input element count can be lowered.
        if exponent_size != 1 and exponent_size != input_size:
            return ctxt, False

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['exponent'] = exponent_tensor.name
        self.operatorRepresentation['data_out'] = data_out.name
        self.operatorRepresentation['size'] = input_size

        return ctxt, True


class DivParser(NodeParser):

def __init__(self):
Expand Down Expand Up @@ -2747,3 +2773,26 @@ def parseNodeCtxt(self,
"ch_im_out"] * self.operatorRepresentation["dim_im_out_y"]
return newCtxt, True
return ctxt, False


class SqrtParser(NodeParser):
    """Parser for ``Sqrt`` nodes with a single input and a single output."""

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        """Return True for ``Sqrt`` nodes with exactly one input and one output."""
        if node.op != 'Sqrt':
            return False
        return len(node.inputs) == 1 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
        """Record the input/output buffer names and the input element count.

        Always returns ``(ctxt, True)``; ``channels_first`` is not used.
        """
        source = ctxt.lookup(node.inputs[0].name)
        result = ctxt.lookup(node.outputs[0].name)

        representation = self.operatorRepresentation
        representation['data_in'] = source.name
        representation['data_out'] = result.name
        # Total number of elements, computed from the input shape
        representation['size'] = int(np.prod(source.shape))

        return ctxt, True
23 changes: 14 additions & 9 deletions Deeploy/Targets/Generic/Platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@
BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \
BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \
BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \
BasicPad1DBindings, BasicPad2DBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, \
BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, \
BasicSliceBindings, BasicSoftmaxBindings, BasicTransposeBindings, DummyBinding
BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \
BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \
BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \
DummyBinding
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \
ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \
ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, SoftmaxLayer, \
TransposeLayer
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \
SoftmaxLayer, SqrtLayer, TransposeLayer
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \
DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \
GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \
IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \
Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \
ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, \
iLayerNormParser, iSoftmaxParser
Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \
RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, SqrtParser, \
TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
Expand All @@ -52,6 +53,8 @@
MaxPoolMapper = NodeMapper(GenericMaxPool2DParser(), BasicMaxPool2DBindings)
MaxPool1DMapper = NodeMapper(MaxPool1DParser(), BasicMaxPool1DBindings)
MulMapper = NodeMapper(MulParser(), BasicMulBindings)
PowMapper = NodeMapper(PowParser(), BasicPowBindings)
SqrtMapper = NodeMapper(SqrtParser(), BasicSqrtBindings)
Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings)
Pad2DMapper = NodeMapper(Pad2DParser(), BasicPad2DBindings)
ReduceMeanMapper = NodeMapper(ReduceMeanParser(), BasicReduceMeanBindings)
Expand Down Expand Up @@ -98,6 +101,8 @@
'MatMulInteger': MatMulLayer([MatMulMapper]),
'MaxPool': MaxPoolLayer([MaxPool1DMapper, MaxPoolMapper]),
'Mul': MulLayer([MulMapper]),
'Pow': PowLayer([PowMapper]),
'Sqrt': SqrtLayer([SqrtMapper]),
'Pad': PadLayer([Pad1DMapper, Pad2DMapper]),
'ReduceMean': ReduceMeanLayer([ReduceMeanMapper]),
'ReduceSum': ReduceSumLayer([ReduceSumMapper]),
Expand Down
59 changes: 59 additions & 0 deletions Deeploy/Targets/Generic/Templates/FloatPowTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
#
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple

import numpy as np

from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation


class _PowTemplate(NodeTemplate):
    """Node template for the float Pow kernels (scalar or per-element exponent)."""

    def alignToContext(self, ctxt: NetworkContext,
                       operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
        """Fill in the fields the Pow template text needs.

        Sets ``data_type``, ``type_width``, ``size``, ``is_scalar`` and
        ``exponent_scalar`` on ``operatorRepresentation``.

        Raises:
            ValueError: if the exponent has more than one element and its
                element count differs from that of the input.
        """
        # Resolve all three buffers up front; the output is looked up but not read further.
        data_in = ctxt.lookup(operatorRepresentation['data_in'])
        exponent = ctxt.lookup(operatorRepresentation['exponent'])
        ctxt.lookup(operatorRepresentation['data_out'])

        # Type name and bit width of the referenced element type (used to pick the kernel name)
        operatorRepresentation['data_type'] = data_in._type.typeName
        operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth

        # Element counts of base and exponent
        input_size = int(np.prod(data_in.shape))
        exponent_size = int(np.prod(exponent.shape))
        operatorRepresentation['size'] = input_size

        scalar_exponent = exponent_size == 1

        # The non-scalar path requires one exponent per input element.
        if not scalar_exponent and input_size != exponent_size:
            raise ValueError(f"Pow operator mismatch: input size ({input_size}) "
                             f"must equal exponent size ({exponent_size}) for non-scalar exponents.")

        operatorRepresentation['is_scalar'] = scalar_exponent
        if scalar_exponent:
            # NOTE(review): the "DeeployNetwork_" prefix is hard-coded here — confirm it
            # matches the generated symbol name for every kind of exponent buffer.
            exponent_name = operatorRepresentation['exponent']
            operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]"
        else:
            operatorRepresentation['exponent_scalar'] = "NULL"

        return ctxt, operatorRepresentation, []


# Template text for Pow: emits the scalar-exponent kernel call when `is_scalar` is set,
# otherwise the kernel call that takes an exponent array. `type_width` selects the
# kernel name suffix (e.g. fp32).
referenceTemplate = _PowTemplate("""
// Pow (Name: ${nodeName}, Op: ${nodeOp})
% if is_scalar:
Pow_fp${type_width}_scalar_fp${type_width}(${data_in}, ${exponent_scalar}, ${data_out}, ${size});
% else:
Pow_fp${type_width}_fp${type_width}_fp${type_width}(${data_in}, ${exponent}, ${data_out}, ${size});
% endif
""")
35 changes: 35 additions & 0 deletions Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
#
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple

import numpy as np

from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation


class _SqrtTemplate(NodeTemplate):
    """Node template for the float Sqrt kernel."""

    def alignToContext(self, ctxt: NetworkContext,
                       operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
        """Fill in ``data_type``, ``type_width`` and ``size`` from the input buffer."""
        # Resolve both buffers up front; the output is looked up but not read further.
        source = ctxt.lookup(operatorRepresentation['data_in'])
        ctxt.lookup(operatorRepresentation['data_out'])

        # Type name and bit width of the referenced element type (used to pick the kernel name)
        operatorRepresentation['data_type'] = source._type.typeName
        operatorRepresentation['type_width'] = source._type.referencedType.typeWidth

        # Total number of elements in the input
        element_count = int(np.prod(source.shape))
        operatorRepresentation['size'] = element_count

        return ctxt, operatorRepresentation, []


# Template text for Sqrt: emits a single kernel call over `size` elements;
# `type_width` selects the kernel name suffix (e.g. fp32).
referenceTemplate = _SqrtTemplate("""
// Sqrt (Name: ${nodeName}, Op: ${nodeOp})
Sqrt_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size});
""")
Binary file added DeeployTest/Tests/testFloatPow/inputs.npz
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatPow/network.onnx
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatPowScalar/inputs.npz
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatPowScalar/network.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
23 changes: 23 additions & 0 deletions DeeployTest/Tests/testFloatPowVector/network.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

deeploy_test_generator:�
3
data_in
exponentdata_outPow_Vector_Test"Powtest_float_pow_vectorZ!
data_in




Z"
exponent




b"
data_out




B
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatRMSNorm/inputs.npz
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatRMSNorm/network.onnx
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatRMSNorm/outputs.npz
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatSqrt/inputs.npz
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatSqrt/network.onnx
Binary file not shown.
Binary file added DeeployTest/Tests/testFloatSqrt/outputs.npz
Binary file not shown.
2 changes: 2 additions & 0 deletions TargetLibraries/Generic/inc/DeeployBasicMath.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@
#include "kernel/MatMul.h"
#include "kernel/MaxPool.h"
#include "kernel/MaxPool1d.h"
#include "kernel/Pow.h"
#include "kernel/RMSNorm.h"
#include "kernel/RQDiv.h"
#include "kernel/RQGELU.h"
#include "kernel/RQHardswish.h"
#include "kernel/Relu.h"
#include "kernel/RequantShift.h"
#include "kernel/Softmax.h"
#include "kernel/Sqrt.h"

#endif //__DEEPLOY_BASIC_MATH_HEADER_
24 changes: 24 additions & 0 deletions TargetLibraries/Generic/inc/kernel/Pow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
*
* SPDX-License-Identifier: Apache-2.0
*/

/*
* This file declares the element-wise binary power kernels.
*/

#ifndef __DEEPLOY_MATH_POW_KERNEL_HEADER_
#define __DEEPLOY_MATH_POW_KERNEL_HEADER_

#include "DeeployBasicMath.h"

/*
 * Element-wise power with an exponent array.
 * data_in:  base values
 * exponent: exponent values (presumably one per base element - confirm
 *           against the kernel implementation)
 * data_out: destination buffer
 * size:     number of elements to process
 */
void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in,
                        const float32_t *__restrict__ exponent,
                        float32_t *__restrict__ data_out, int32_t size);

/*
 * Element-wise power with a single scalar exponent passed by value.
 * data_in:  base values
 * exponent: scalar exponent (presumably applied to every element - confirm
 *           against the kernel implementation)
 * data_out: destination buffer
 * size:     number of elements to process
 */
void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in,
                          float32_t exponent, float32_t *__restrict__ data_out,
                          int32_t size);

#endif
22 changes: 22 additions & 0 deletions TargetLibraries/Generic/inc/kernel/Sqrt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* SPDX-FileCopyrightText: 2020 ETH Zurich and University of Bologna
*
* SPDX-License-Identifier: Apache-2.0
*/

#ifndef __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_
#define __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_

#include "DeeployBasicMath.h"

/*
* Square root operation - computes sqrt for each element
*/

/******************************************************************************/
/* Sqrt */
/******************************************************************************/

/*
 * data_in:  source buffer
 * data_out: destination buffer
 * size:     number of elements to process
 * NOTE(review): unlike the Pow kernels, the pointers here are not
 * const/__restrict__ qualified - confirm whether that is intentional.
 */
void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size);

#endif //__DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_
Loading