
Commit 8f90620

Refactor Pow for float support, remove FP16, and clean up parsers
This commit addresses code review feedback:

- Refactor the Pow kernel to use 'powf' from math.h to support floating-point exponents.
- Update PowParser to allow tensor exponents instead of forcing constants.
- Remove Generic FP16 support and revert the types.h changes.
- Remove the duplicate PowParser/SqrtParser classes.
- Enhance the RMSNorm tests with larger shapes and non-trivial weights.
1 parent fee8470 commit 8f90620

12 files changed (+56 / -147 lines)


Deeploy/DeeployTypes.py

Lines changed: 6 additions & 5 deletions
@@ -325,15 +325,15 @@ def fromNode(cls, node: gs.Node):
         return (cls(name = node.name, shape = node.shape if not isinstance(node, gs.Constant) else node.values.shape))

     def has_live_aliases(self, ctxt: NetworkContext) -> bool:
-        """Checks whether this VariableBuffer has any live ancestors, i.e. buffers that are still live and are aliased by this buffer.
+        """Checks whether this VariableBuffer has any live aliases, i.e. buffers that are still live and are aliased by this buffer.
         Parameters
         ----------
         ctxt : NetworkContext
             Current NetworkContext
         Returns
         -------
         bool
-            True if this VariableBuffer has any live ancestors, False otherwise
+            True if this VariableBuffer has any live aliases, False otherwise
         """
         # Do a breadth-first search across the aliasing double-linked list
         live = self._live
@@ -2562,10 +2562,10 @@ def codeTransform(self, verbose: CodeGenVerbosity = _NoVerbosity):
             self.ctxt = layer.codeTransform(self.ctxt, verbose)
         self.transformed = True

-    def _mapNode(self, node: gs.Node) -> Union[ONNXLayer, Any]:
+    def _selectEngine(self, node: gs.Node) -> DeploymentEngine:
         for engine in self.Platform.engines:
             if node.op in engine.Mapping:
-                return engine.Mapping[node.op](node)
+                return engine
         raise RuntimeError(f"No mapping found for node {node.name} with op type {node.op}")

     def _bindLayers(self):
@@ -2582,7 +2582,8 @@ def _bindLayers(self):
             flatSchedule += subGraph

         for node in flatSchedule:
-            layer = self._mapNode(node)
+            engine = self._selectEngine(node)
+            layer = engine.Mapping[node.op](node)
             if isinstance(layer, ONNXLayer):
                 log.debug(f" {SUCCESS_MARK} Bind {node.name} to layer {layer.__class__.__name__}")
                 self.layerBinding[layer.node.name] = layer

Deeploy/Targets/Generic/Bindings.py

Lines changed: 1 addition & 5 deletions
@@ -7,7 +7,7 @@
 from Deeploy.AbstractDataTypes import PointerClass
 from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \
     MemoryManagementGeneration, MemoryPassthroughGeneration
-from Deeploy.CommonExtensions.DataTypes import FloatDataTypes, IntegerDataTypes, SignedIntegerDataTypes, float16_t, \
+from Deeploy.CommonExtensions.DataTypes import FloatDataTypes, IntegerDataTypes, SignedIntegerDataTypes, \
     float32_t, int8_t, int32_t, uint8_t
 from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
 from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
@@ -121,15 +121,11 @@
 BasicPowBindings = [
     NodeBinding(DummyChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
                 FloatPowTemplate.referenceTemplate, BasicTransformer),
-    NodeBinding(DummyChecker([PointerClass(float16_t), PointerClass(float16_t)], [PointerClass(float16_t)]),
-                FloatPowTemplate.referenceTemplate, BasicTransformer)
 ]

 BasicSqrtBindings = [
     NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatSqrtTemplate.referenceTemplate,
                 BasicTransformer),
-    NodeBinding(DummyChecker([PointerClass(float16_t)], [PointerClass(float16_t)]), FloatSqrtTemplate.referenceTemplate,
-                BasicTransformer)
 ]

 BasicDivBindings = [

Deeploy/Targets/Generic/Parsers.py

Lines changed: 1 addition & 62 deletions
@@ -8,7 +8,7 @@
 import numpy as np
 import onnx_graphsurgeon as gs

-from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer
+from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer, ConstantBuffer


 class ConcatParser(NodeParser):
@@ -2000,29 +2000,6 @@ def parseNodeCtxt(self,
         return ctxt, True


-class SqrtParser(NodeParser):
-
-    def __init__(self):
-        super().__init__()
-
-    def parseNode(self, node: gs.Node) -> bool:
-        return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1
-
-    def parseNodeCtxt(self,
-                      ctxt: NetworkContext,
-                      node: gs.Node,
-                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
-
-        data_in = ctxt.lookup(node.inputs[0].name)
-        data_out = ctxt.lookup(node.outputs[0].name)
-
-        self.operatorRepresentation['data_in'] = data_in.name
-        self.operatorRepresentation['data_out'] = data_out.name
-        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
-
-        return ctxt, True
-
-
 class DivParser(NodeParser):

     def __init__(self):
@@ -2808,44 +2785,6 @@ def parseNodeCtxt(self,
         return ctxt, False


-############################
-
-
-class PowParser(NodeParser):
-
-    def __init__(self):
-        super().__init__()
-
-    def parseNode(self, node: gs.Node) -> bool:
-        return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1
-
-    def parseNodeCtxt(self,
-                      ctxt: NetworkContext,
-                      node: gs.Node,
-                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
-
-        data_in = ctxt.lookup(node.inputs[0].name)
-        exponent = node.inputs[1]
-        data_out = ctxt.lookup(node.outputs[0].name)
-
-        self.operatorRepresentation['data_in'] = data_in.name
-        self.operatorRepresentation['data_out'] = data_out.name
-
-        # Check if exponent is a constant
-        if isinstance(exponent, gs.Constant):
-            exp_value = float(exponent.values)
-            self.operatorRepresentation['exponent'] = exp_value
-            self.operatorRepresentation['is_constant_exp'] = True
-        else:
-            exp_tensor = ctxt.lookup(exponent.name)
-            self.operatorRepresentation['exponent'] = exp_tensor.name
-            self.operatorRepresentation['is_constant_exp'] = False
-
-        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
-
-        return ctxt, True
-
-
 class SqrtParser(NodeParser):

     def __init__(self):
Lines changed: 21 additions & 26 deletions
@@ -1,48 +1,43 @@
 # SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
 #
 # SPDX-License-Identifier: Apache-2.0
-
 from typing import Dict, List, Tuple
-
 import numpy as np
-
 from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation

-
 class _PowTemplate(NodeTemplate):
-
     def alignToContext(self, ctxt: NetworkContext,
                        operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
-
         # Get input and output tensors
         data_in = ctxt.lookup(operatorRepresentation['data_in'])
+        exponent = ctxt.lookup(operatorRepresentation['exponent'])
         data_out = ctxt.lookup(operatorRepresentation['data_out'])
-
-        # Get data type (fp32 or fp16)
+
+        # Get data type (fp32)
         data_type = data_in._type.typeName
         operatorRepresentation['data_type'] = data_type
-
-        # Exponent must be a constant integer
-        if 'exponent' in operatorRepresentation:
-            exponent_input = operatorRepresentation['exponent']
-            if isinstance(exponent_input, str):
-                # It's a tensor name - not supported for integer exponent version
-                raise ValueError("Tensor exponent not supported. Use constant integer exponent.")
-            else:
-                # Convert to integer
-                operatorRepresentation['exponent_value'] = int(exponent_input)
-
+
         # Calculate size
-        operatorRepresentation['size'] = int(np.prod(data_in.shape))
-
+        input_size = int(np.prod(data_in.shape))
+        exponent_size = int(np.prod(exponent.shape))
+        operatorRepresentation['size'] = input_size
+
+        # Check if exponent is scalar (broadcasting)
+        if exponent_size == 1:
+            operatorRepresentation['is_scalar'] = True
+            # Get the full variable name with prefix
+            exponent_name = operatorRepresentation['exponent']
+            operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]"
+        else:
+            operatorRepresentation['is_scalar'] = False
+
         return ctxt, operatorRepresentation, []

-
 referenceTemplate = _PowTemplate("""
 // Pow (Name: ${nodeName}, Op: ${nodeOp})
-% if 'float32' in data_type:
-Pow_fp32_int32_fp32(${data_in}, ${exponent_value}, ${data_out}, ${size});
-% elif 'float16' in data_type:
-Pow_fp16_int32_fp16(${data_in}, ${exponent_value}, ${data_out}, ${size});
+% if is_scalar:
+Pow_fp32_scalar_fp32(${data_in}, ${exponent_scalar}, ${data_out}, ${size});
+% else:
+Pow_fp32_fp32_fp32(${data_in}, ${exponent}, ${data_out}, ${size});
 % endif
 """)
Lines changed: 4 additions & 15 deletions
@@ -1,38 +1,27 @@
 # SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
 #
 # SPDX-License-Identifier: Apache-2.0
-
 from typing import Dict, List, Tuple
-
 import numpy as np
-
 from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation

-
 class _SqrtTemplate(NodeTemplate):
-
     def alignToContext(self, ctxt: NetworkContext,
                        operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
-
         # Get input and output tensors
         data_in = ctxt.lookup(operatorRepresentation['data_in'])
         data_out = ctxt.lookup(operatorRepresentation['data_out'])
-
-        # Get data type (fp32 or fp16)
+
+        # Get data type (fp32)
         data_type = data_in._type.typeName
         operatorRepresentation['data_type'] = data_type
-
+
         # Calculate size
         operatorRepresentation['size'] = int(np.prod(data_in.shape))
-
+
         return ctxt, operatorRepresentation, []

-
 referenceTemplate = _SqrtTemplate("""
 // Sqrt (Name: ${nodeName}, Op: ${nodeOp})
-% if 'float32' in data_type:
 Sqrt_fp32_fp32(${data_in}, ${data_out}, ${size});
-% elif 'float16' in data_type:
-Sqrt_fp16_fp16(${data_in}, ${data_out}, ${size});
-% endif
 """)
3 binary files changed (56 KB, 763 Bytes, 56 KB); contents not shown.

TargetLibraries/Generic/inc/kernel/Pow.h

Lines changed: 9 additions & 8 deletions
@@ -8,18 +8,19 @@
  * This file implements the element-wise binary power operation.
  */

-/******************************************************************************/
-/* Power (32bit) */
-/******************************************************************************/
-
 #ifndef __DEEPLOY_MATH_POW_KERNEL_HEADER_
 #define __DEEPLOY_MATH_POW_KERNEL_HEADER_

 #include "DeeployBasicMath.h"

-void Pow_fp32_int32_fp32(float32_t *data_in, int32_t exponent,
-                         float32_t *data_out, int32_t size);
+void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in,
+                        const float32_t *__restrict__ exponent,
+                        float32_t *__restrict__ data_out,
+                        int32_t size);
+
+void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in,
+                          float32_t exponent,
+                          float32_t *__restrict__ data_out,
+                          int32_t size);

-void Pow_fp16_int32_fp16(float16_t *data_in, int32_t exponent,
-                         float16_t *data_out, int32_t size);
 #endif
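The refactored kernel source itself is not part of this view. A minimal sketch of implementations consistent with these declarations and with the commit message's switch to powf from math.h could look like the following (illustrative only, not the committed code):

#include <math.h>

#include "DeeployBasicMath.h"

/* Element-wise power with a full exponent tensor: data_out[i] = data_in[i] raised to exponent[i]. */
void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in,
                        const float32_t *__restrict__ exponent,
                        float32_t *__restrict__ data_out, int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = powf(data_in[i], exponent[i]);
  }
}

/* Broadcast variant: a single scalar exponent applied to every element. */
void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in, float32_t exponent,
                          float32_t *__restrict__ data_out, int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = powf(data_in[i], exponent);
  }
}

Keeping the exponent in floating point is what allows non-integer exponents (e.g. 0.5), which the previous integer-exponent kernel could not express.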

TargetLibraries/Generic/inc/kernel/Sqrt.h

Lines changed: 0 additions & 2 deletions
@@ -19,6 +19,4 @@

 void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size);

-void Sqrt_fp16_fp16(float16_t *data_in, float16_t *data_out, int32_t size);
-
 #endif //__DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_
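For completeness, a comparable sketch of the remaining fp32 square-root kernel (again illustrative; the committed kernel source is not shown in this view) would use sqrtf from math.h:

#include <math.h>

#include "DeeployBasicMath.h"

/* Element-wise square root: data_out[i] = sqrtf(data_in[i]). */
void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = sqrtf(data_in[i]);
  }
}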
