Commit ecae48a

CCT Attention Training on Siracusa (#69)
This PR introduces gradient operator support, improved GEMM performance, and updates to the CCT training workflow. It also includes fixes to tile constraints and naming consistency in the transpose splitting pass.

## Added

- Implemented LayerNormGrad and GeluGrad operator parsers, bindings, templates, and tile constraints.
- Added CCT linear probing, LoRA, and full backpropagation training graphs.

## Changed

- Optimized the float GEMM kernel with loop unrolling and improved transpose handling.

## Fixed

- Corrected float GEMM tile constraints and templates for the no-bias case.
- Fixed the transpose splitting pass logic: resolved repeated names by appending identifiers derived from the source and destination nodes.
1 parent 363b7d5 · commit ecae48a

58 files changed (+1161 / −127 lines)

.github/workflows/ci-platform-siracusa-tiled.yml

Lines changed: 7 additions & 9 deletions

```diff
@@ -135,9 +135,7 @@ jobs:
 - name: "MLPerf/AnomalyDetection"
   L1: [64000]
 - name: "CCT/CCT_1_16_16_8"
-  L1: [2000, 64000]
-- name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8"
-  L1: [4000, 64000]
+  L1: [64000]
 - name: "testFloatDemoTinyViT"
   L1: [4000]
 num-cores: [8]
@@ -168,9 +166,9 @@ jobs:
 - name: "microLlama/microLlama1"
   L1: [60000, 10000, 5000]
 - name: "CCT/CCT_2_32_32_128"
-  L1: [64000, 128000]
-- name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128"
-  L1: [32000, 64000]
+  L1: [128000]
+- name: "testTrainCCT/CCT2_FT2"
+  L1: [128000]
 - name: "testFloatDemoTinyViT"
   L1: [4000]
 num-cores: [8]
@@ -208,9 +206,9 @@ jobs:
 - name: "microLlama/microLlama8_parallel"
   L1: [60000, 20000, 10000]
 - name: "CCT/CCT_2_32_32_128"
-  L1: [64000, 128000]
-- name: "testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_128"
-  L1: [8000, 64000]
+  L1: [128000]
+- name: "testTrainCCT/CCT2_FT2"
+  L1: [128000]
 - name: "testFloatDemoTinyViT"
   L1: [4000]
 num-cores: [8]
```

.github/workflows/ci-platform-siracusa.yml

Lines changed: 0 additions & 1 deletion

```diff
@@ -95,6 +95,5 @@ jobs:
 MLPerf/AnomalyDetection
 CCT/CCT_1_16_16_8
 CCT/CCT_2_32_32_128_Opset20
-testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8
 testFloatDemoTinyViT
 num-cores: 8
```

Deeploy/Targets/Generic/Layers.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -58,6 +58,18 @@ def computeOps(self):
         return mul1 + neg + exp + add + div + mul2
 
 
+class GELUGradLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+    def computeOps(self):
+        size = self.mapper.parser.operatorRepresentation['size']
+        ops_per_element = 9
+        gelu_grad_ops = size * ops_per_element
+        return gelu_grad_ops
+
+
 class iHardswishLayer(ONNXLayer):
 
     def __init__(self, maps: List[NodeMapper]):
@@ -450,6 +462,12 @@ def computeOps(self):
         return compAverage + compNormalize + compSqr + compSum + compSqrt + compDiv
 
 
+class LayerNormGradLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+
 class TransposeLayer(ONNXLayer):
 
     def __init__(self, maps: List[NodeMapper]):
```
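For context: the forward GELU layer above counts six ops per element (`mul1 + neg + exp + add + div + mul2`), i.e. one evaluation of x·σ(1.702x). The nine ops per element assumed by `GELUGradLayer` plausibly correspond to the derivative of that same sigmoid approximation, which the kernel name `GELU_fp…_sigmoid_grad_chunk` in the template below also suggests. A sketch of the assumed formula:

```latex
% Assumed sigmoid approximation of GELU and its derivative:
\mathrm{GELU}(x) \approx x\,\sigma(1.702x), \qquad
\frac{d}{dx}\,\mathrm{GELU}(x) \approx \sigma(1.702x)
  + 1.702\,x\,\sigma(1.702x)\bigl(1 - \sigma(1.702x)\bigr),
\qquad
\texttt{grad\_out} = \texttt{grad\_in}\cdot\frac{d}{dx}\,\mathrm{GELU}(x)
```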

Deeploy/Targets/Generic/Parsers.py

Lines changed: 57 additions & 0 deletions

```diff
@@ -770,6 +770,33 @@ def parseNodeCtxt(self,
         return ctxt, True
 
 
+class GELUGradParser(NodeParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+
+        ret = all([len(node.inputs) == 2, len(node.outputs) == 1])
+        return ret
+
+    def parseNodeCtxt(self,
+                      ctxt: NetworkContext,
+                      node: gs.Node,
+                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
+
+        upstream_grad = ctxt.lookup(node.inputs[0].name)
+        gelu_input = ctxt.lookup(node.inputs[1].name)
+        gelu_grad = ctxt.lookup(node.outputs[0].name)
+
+        self.operatorRepresentation['grad_in'] = upstream_grad.name
+        self.operatorRepresentation['data_in'] = gelu_input.name
+        self.operatorRepresentation['grad_out'] = gelu_grad.name
+        self.operatorRepresentation['size'] = np.prod(upstream_grad.shape)
+
+        return ctxt, True
+
+
 class RQSiGELUParser(GELUParser):
 
     def __init__(self):
@@ -1647,6 +1674,36 @@ def parseNodeCtxt(self,
         return ctxt, True
 
 
+class LayerNormGradParser(iLayerNormParser):
+
+    def parseNode(self, node: gs.Node) -> (bool):
+
+        ret = all(['epsilon' in node.attrs, len(node.inputs) == 4, len(node.outputs) == 1])
+
+        if ret:
+            self.operatorRepresentation['epsilon'] = node.attrs['epsilon']
+
+        return ret
+
+    def parseNodeCtxt(self,
+                      ctxt: NetworkContext,
+                      node: gs.Node,
+                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
+
+        inputs = ['grad_in', 'data_in', 'weight', 'bias']
+        outputs = ['grad_out']
+
+        for idx, inputNode in enumerate(node.inputs):
+            self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
+        for idx, outputNode in enumerate(node.outputs):
+            self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
+
+        self.operatorRepresentation['size'] = np.prod(ctxt.lookup(node.inputs[0].name).shape)
+        self.operatorRepresentation['lastDimLength'] = ctxt.lookup(node.inputs[0].name).shape[-1]
+
+        return ctxt, True
+
+
 class MatMulParser(NodeParser):
 
     def __init__(self, noBiasHoisting = True):
```
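A minimal sketch of the node shape the new `GELUGradParser` accepts, built directly with onnx_graphsurgeon; the node and tensor names here are invented for illustration:

```python
import numpy as np
import onnx_graphsurgeon as gs

from Deeploy.Targets.Generic.Parsers import GELUGradParser

# GeluGrad consumes the upstream gradient and the original GELU input,
# and produces the gradient w.r.t. that input.
grad_y = gs.Variable("grad_y", dtype=np.float32, shape=(1, 64))
x = gs.Variable("x", dtype=np.float32, shape=(1, 64))
grad_x = gs.Variable("grad_x", dtype=np.float32, shape=(1, 64))

node = gs.Node(op="GeluGrad", name="gelu_grad_0", inputs=[grad_y, x], outputs=[grad_x])

parser = GELUGradParser()
assert parser.parseNode(node)  # accepted: exactly two inputs, one output
```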

Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -676,8 +676,8 @@ def _split_transposes_fun(graph: gs.Graph, match: Match, name: str):
     inputNode.outputs = [postSplitOutput]
 
     for node in originalNode.outputs.copy():
-        nodeName = node.name + f"_transpose_in"
-        varName = node.name + f"_transpose_in_var"
+        nodeName = f"{t1.name}_{node.name}_transpose_in"
+        varName = f"{t1.name}_{node.name}_transpose_in_var"
         newOutput = gs.Variable(name = varName, dtype = np.float32, shape = t1.outputs[0].shape)
 
         transposeNode = gs.Node(name = nodeName,
```
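The old names depended only on the consumer node, so splitting several transposes that feed the same consumer produced duplicate node names; prefixing with the source transpose's name makes each source/consumer pair unique. A hypothetical illustration (all names invented):

```python
# Two split transposes feeding the same consumer: the old scheme collides,
# the new source-prefixed scheme does not.
sources, consumer = ["transpose_0", "transpose_1"], "matmul_0"

old_names = [f"{consumer}_transpose_in" for _ in sources]
new_names = [f"{src}_{consumer}_transpose_in" for src in sources]

assert len(set(old_names)) == 1  # duplicate names
assert len(set(new_names)) == 2  # unique per source/consumer pair
```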

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -415,10 +415,22 @@
         PointerClass(float32_t)], [PointerClass(float32_t)]), FloatLayernormTemplate.referenceTemplate,
     ForkTransformer)
 
+PULPLayernormGradBinding = NodeBinding(
+    LayerNormChecker(
+        [PointerClass(float32_t),
+         PointerClass(float32_t),
+         PointerClass(float32_t),
+         PointerClass(float32_t)], [PointerClass(float32_t)]), FloatLayernormTemplate.referenceGradTemplate,
+    ForkTransformer)
+
 PULPFloatGELUBinding = NodeBinding(
     GELUChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
     FloatGELUTemplate.referenceTemplate, ForkTransformer)
 
+PULPFloatGELUGradBinding = NodeBinding(
+    GELUChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
+    FloatGELUTemplate.referenceGradTemplate, ForkTransformer)
+
 PULPGatherBindings = [
     NodeBinding(GatherChecker([PointerClass(float32_t), PointerClass(type)], [PointerClass(float32_t)]),
                 GatherTemplate.referenceTemplate, ForkTransformer) for type in IntegerDataTypes
```
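For orientation, the pointer signatures checked by the new bindings line up with the operand names the parsers above assign; a sketch of that correspondence:

```python
# Binding type signatures vs. parser operand names (see Parsers.py above).
layernorm_grad_io = {
    "inputs": ["grad_in", "data_in", "weight", "bias"],  # 4x PointerClass(float32_t)
    "outputs": ["grad_out"],                             # 1x PointerClass(float32_t)
}
gelu_grad_io = {
    "inputs": ["grad_in", "data_in"],                    # 2x PointerClass(float32_t)
    "outputs": ["grad_out"],                             # 1x PointerClass(float32_t)
}
```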

Deeploy/Targets/PULPOpen/Platform.py

Lines changed: 24 additions & 18 deletions

```diff
@@ -13,17 +13,18 @@
 from Deeploy.MemoryLevelExtension.NetworkDeployers.MemoryLevelDeployer import MemoryPlatform, MemoryPlatformWrapper
 from Deeploy.Targets.Generic.Bindings import BasicGEMMBindings, BasicPad1DBindings, BasicPad2DBindings, \
     BasicRQIntegerDivBinding
-from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELULayer, GEMMLayer, \
-    LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \
-    ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, RQSiHardswishLayer, SGDLayer, \
-    SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, SoftmaxGradLayer, SoftmaxLayer, \
-    TransposeLayer, iHardswishLayer, iRMSNormLayer
+from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELUGradLayer, GELULayer, \
+    GEMMLayer, LayerNormGradLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, \
+    ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
+    RQSiHardswishLayer, SGDLayer, SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, \
+    SoftmaxGradLayer, SoftmaxLayer, TransposeLayer, iHardswishLayer, iRMSNormLayer
 from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, DequantParser, FlattenParser, GatherParser, \
-    GELUParser, GEMMParser, LayerNormParser, MatMulParser, MaxPool2DParser, MulParser, Pad1DParser, Pad2DParser, \
-    QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQAddParser, \
-    RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, SGDParser, SliceParser, \
-    SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, SoftmaxParser, \
-    TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser
+    GELUGradParser, GELUParser, GEMMParser, LayerNormGradParser, LayerNormParser, MatMulParser, MaxPool2DParser, \
+    MulParser, Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \
+    RequantShiftParser, ReshapeParser, RQAddParser, RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, \
+    SGDParser, SliceParser, SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, \
+    SoftmaxParser, TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, \
+    iSoftmaxParser
 from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate
 from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, IntegerDivRequantMergePass, \
     MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, QuantPatternPass, RQSSplitPass, \
@@ -37,14 +38,15 @@
 from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate
 from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \
     PULPConv2DTilingReadyBindings, PULPDWConv2DTilingReadyBindings, PULPFlattenTilingReadyBindings, \
-    PULPFPGELUTilingReadyBindings, PULPFPGEMMTilingReadyBindings, PULPGatherTilingReadyBindings, \
-    PULPiHardswishTilingReadyBindings, PULPiRMSNormTilingReadyBindings, PULPiRQSGELUTilingReadyBindings, \
-    PULPLayernormTilingReadyBindings, PULPMatMulTilingReadyBindings, PULPMaxPool2DTilingReadyBindings, \
-    PULPMulTilingReadyBindings, PULPReduceMeanTilingReadyBindings, PULPReduceSumTilingReadyBindings, \
-    PULPReluTilingReadyBindings, PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, \
-    PULPRQSDWConv2DTilingReadyBindings, PULPRQSGEMMTilingReadyBindings, PULPRQSiHardswishTilingReadyBindings, \
-    PULPRQSMatrixVecTilingReadyBindings, PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, \
-    PULPSGDTilingReadyBindings, PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \
+    PULPFPGELUGradTilingReadyBindings, PULPFPGELUTilingReadyBindings, PULPFPGEMMTilingReadyBindings, \
+    PULPGatherTilingReadyBindings, PULPiHardswishTilingReadyBindings, PULPiRMSNormTilingReadyBindings, \
+    PULPiRQSGELUTilingReadyBindings, PULPLayernormGradTilingReadyBindings, PULPLayernormTilingReadyBindings, \
+    PULPMatMulTilingReadyBindings, PULPMaxPool2DTilingReadyBindings, PULPMulTilingReadyBindings, \
+    PULPReduceMeanTilingReadyBindings, PULPReduceSumTilingReadyBindings, PULPReluTilingReadyBindings, \
+    PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, PULPRQSDWConv2DTilingReadyBindings, \
+    PULPRQSGEMMTilingReadyBindings, PULPRQSiHardswishTilingReadyBindings, PULPRQSMatrixVecTilingReadyBindings, \
+    PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, PULPSGDTilingReadyBindings, \
+    PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \
     PULPSoftmaxCrossEntropyTilingReadyBindings, PULPSoftmaxGradTilingReadyBindings, PULPSoftmaxTilingReadyBindings, \
     PULPTransposeTilingReadyBindings, PULPUniformRQSTilingReadyBindings
 from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPAddRequantMergePass, \
@@ -54,6 +56,7 @@
 AddMapper = NodeMapper(AddParser(), PULPAddTilingReadyBindings)
 FlattenMapper = NodeMapper(FlattenParser(), PULPFlattenTilingReadyBindings)
 GELUMapper = NodeMapper(GELUParser(), PULPFPGELUTilingReadyBindings)
+GELUGradMapper = NodeMapper(GELUGradParser(), PULPFPGELUGradTilingReadyBindings)
 GatherMapper = NodeMapper(GatherParser(), PULPGatherTilingReadyBindings)
 MulMapper = NodeMapper(MulParser(), PULPMulTilingReadyBindings)
 Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings)
@@ -83,6 +86,7 @@
 TallGEMMMapper = NodeMapper(PULPTallGEMMParser(), PULPRQSTallGEMMTilingReadyBindings)
 MaxPool2DMapper = NodeMapper(MaxPool2DParser(), PULPMaxPool2DTilingReadyBindings)
 LayerNormMapper = NodeMapper(LayerNormParser(), PULPLayernormTilingReadyBindings)
+LayerNormGradMapper = NodeMapper(LayerNormGradParser(), PULPLayernormGradTilingReadyBindings)
 ReluMapper = NodeMapper(ReluParser(), PULPReluTilingReadyBindings)
 SoftmaxMapper = NodeMapper(SoftmaxParser(), PULPSoftmaxTilingReadyBindings)
 SoftmaxGradMapper = NodeMapper(SoftmaxGradParser(), PULPSoftmaxGradTilingReadyBindings)
@@ -111,7 +115,9 @@
     'RequantizedGemm': PULPRQSGEMMLayer([MatrixVecMapper, TallGEMMMapper, GEMMMapper]),
     'Gemm': GEMMLayer([FloatGEMMMapper, GEMMDequantMapper]),
     'Gelu': GELULayer([GELUMapper]),
+    'GeluGrad': GELUGradLayer([GELUGradMapper]),
     'LayerNormalization': LayerNormLayer([LayerNormMapper]),
+    'LayerNormalizationGrad': LayerNormGradLayer([LayerNormGradMapper]),
     'MaxPool': MaxPoolLayer([MaxPool2DMapper]),
     'RequantizediGELU': RQSiGELULayer([RQGELU_int8_Mapper]),
     'RQIntegerDiv': RQIntegerDivLayer([RQIntegerDivMapper]),
```

Deeploy/Targets/PULPOpen/Templates/FloatGELUTemplate.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -7,4 +7,14 @@
 referenceTemplate = NodeTemplate("""
 // GELU (Name: ${nodeName}, Op: ${nodeOp})
 PULP_GELU_fp${data_in_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}(${data_in}, ${data_out}, ${size});
+""")
+
+referenceGradTemplate = NodeTemplate("""
+// GELU Parallel (Name: ${nodeName}, Op: ${nodeOp})
+int8_t ${nodeName}_core_id = pi_core_id();
+int8_t ${nodeName}_log2Core = log2(NUM_CORES);
+int16_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0);
+int16_t ${nodeName}_chunk_start = MIN(${nodeName}_chunk*${nodeName}_core_id, ${size});
+int16_t ${nodeName}_chunk_stop = MIN(${nodeName}_chunk_start + ${nodeName}_chunk, ${size});
+GELU_fp${data_in_type.referencedType.typeWidth}_fp${grad_out_type.referencedType.typeWidth}_sigmoid_grad_chunk(${grad_in}, ${data_in}, ${grad_out}, ${nodeName}_chunk_start, ${nodeName}_chunk_stop);
 """)
```

Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py

Lines changed: 42 additions & 4 deletions

```diff
@@ -2,16 +2,42 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from Deeploy.DeeployTypes import NodeTemplate
+from typing import Dict, List, Tuple
 
-referenceTemplate = NodeTemplate("""
+from Deeploy.AbstractDataTypes import float32_tPtr
+from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
+
+
+class PULPFloatGEMMTemplate(NodeTemplate):
+
+    def __init__(self, templateStr):
+        super().__init__(templateStr)
+
+    def alignToContext(self, ctxt: NetworkContext,
+                       operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
+
+        if 'C' not in operatorRepresentation or operatorRepresentation['C'] is None:
+            # No bias case - set C to NULL and provide a default type
+            operatorRepresentation['C'] = None
+            operatorRepresentation['C_type'] = float32_tPtr  # Default to fp32 type
+            operatorRepresentation['C_batched'] = False
+
+        return ctxt, operatorRepresentation, []
+
+
+referenceTemplate = PULPFloatGEMMTemplate("""
 // GEMM (Name: ${nodeName}, Op: ${nodeOp})
 ${A_type.typeName} ref_${data_out}_${A} = ${A};
 ${B_type.typeName} ref_${data_out}_${B} = ${B};
+% if C is not None:
 ${C_type.typeName} ref_${data_out}_${C} = ${C};
+% else:
+${C_type.typeName} ref_${data_out}_C = NULL;
+% endif
 ${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
 
 for(uint32_t i=0; i<${batch}; i++){
+% if C is not None:
     PULP_Gemm_fp${A_type.referencedType.typeWidth}_fp${B_type.referencedType.typeWidth}_fp${C_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}(
         ref_${data_out}_${A},
         ref_${data_out}_${B},
@@ -23,7 +49,19 @@
         ${transA},
         ${transB}
     );
-
+% else:
+    PULP_Gemm_fp${A_type.referencedType.typeWidth}_fp${B_type.referencedType.typeWidth}_fp${C_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}(
+        ref_${data_out}_${A},
+        ref_${data_out}_${B},
+        NULL,
+        ref_${data_out}_${data_out},
+        ${M},
+        ${N},
+        ${O},
+        ${transA},
+        ${transB}
+    );
+% endif
 % if A_batched:
     ref_${data_out}_${A} += ${M} * ${N};
 % endif
@@ -32,7 +70,7 @@
     ref_${data_out}_${B} += ${N} * ${O};
 % endif
 
-% if C_batched:
+% if C is not None and C_batched:
     ref_${data_out}_${C} += ${M} * ${O};
 % endif
 
```
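As a rough NumPy sketch (not Deeploy code) of the semantics the generated call implements, with A of shape M×N and B of shape N×O as implied by the batched pointer advances above, and the bias term dropped on the no-bias path:

```python
import numpy as np

def gemm_ref(A, B, C=None, transA=0, transB=0):
    """Reference GEMM: Y = op(A) @ op(B) (+ C when a bias is present),
    assuming transA/transB follow the usual ONNX Gemm convention."""
    A = A.T if transA else A
    B = B.T if transB else B
    Y = A @ B
    return Y if C is None else Y + C

A = np.random.rand(4, 8).astype(np.float32)  # M x N
B = np.random.rand(8, 3).astype(np.float32)  # N x O
assert gemm_ref(A, B).shape == (4, 3)        # no-bias case lowers C to NULL
```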

Deeploy/Targets/PULPOpen/Templates/FloatLayernormTemplate.py

Lines changed: 34 additions & 0 deletions

```diff
@@ -15,4 +15,38 @@
     ${size},
     ${lastDimLength}
 );
+""")
+
+referenceGradTemplate = NodeTemplate("""
+// FloatLayernormGrad Parallel (Name: ${nodeName}, Op: ${nodeOp})
+
+int8_t ${nodeName}_core_id = pi_core_id();
+int8_t ${nodeName}_log2Core = log2(NUM_CORES);
+
+int32_t ${nodeName}_seq_length = ${size} / ${lastDimLength};
+int32_t ${nodeName}_chunk = (${nodeName}_seq_length >> ${nodeName}_log2Core) +
+                            ((${nodeName}_seq_length & (NUM_CORES-1)) != 0);
+int32_t ${nodeName}_start = MIN(${nodeName}_chunk * ${nodeName}_core_id, ${nodeName}_seq_length);
+int32_t ${nodeName}_end = MIN(${nodeName}_start + ${nodeName}_chunk, ${nodeName}_seq_length);
+
+int32_t ${nodeName}_elem_start = ${nodeName}_start * ${lastDimLength};
+int32_t ${nodeName}_elem_end = ${nodeName}_end * ${lastDimLength};
+int32_t ${nodeName}_elem_count = ${nodeName}_elem_end - ${nodeName}_elem_start;
+
+const float${grad_in_type.referencedType.typeWidth}_t* ${nodeName}_grad_in_ptr = ${grad_in} + ${nodeName}_elem_start;
+const float${data_in_type.referencedType.typeWidth}_t* ${nodeName}_data_in_ptr = ${data_in} + ${nodeName}_elem_start;
+float${grad_out_type.referencedType.typeWidth}_t* ${nodeName}_grad_out_ptr = ${grad_out} + ${nodeName}_elem_start;
+
+if (${nodeName}_elem_count > 0) {
+    LayernormGrad_fp${grad_in_type.referencedType.typeWidth}_fp${grad_out_type.referencedType.typeWidth}(
+        ${nodeName}_grad_in_ptr,   // Upstream gradient (dy)
+        ${nodeName}_data_in_ptr,   // Original input (x)
+        ${nodeName}_grad_out_ptr,  // Output gradient (dx)
+        ${weight},                 // Input Scale parameter
+        ${bias},                   // Input Bias parameter
+        ${epsilon},                // Epsilon for numerical stability
+        ${nodeName}_elem_count,    // Number of elements to process
+        ${lastDimLength}           // Size of the feature dimension
+    );
+}
 """)
```
