Skip to content

Commit c7b9771

Browse files
committed
refactor: address reviewer comments and reduce code duplication
- CI: split snitch-kernels and snitch-models into separate jobs (matching Siracusa pattern) - Revert out-of-scope LoweringOptimizationPasses.py change - Merge duplicate FloatDivTileConstraint/FloatMulTileConstraint into FloatScalarBOPTileConstraint - Consolidate TileConstraint imports in Tiler.py - Remove unused imports (NodeParser in Parsers.py, BasicConcatBindings in Platform.py)
1 parent 6b357cf commit c7b9771

File tree

8 files changed

+27
-114
lines changed

8 files changed

+27
-114
lines changed

.github/workflows/ci-platform-snitch.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,12 @@ jobs:
3535
with:
3636
runner: ${{ needs.select-env.outputs.runner }}
3737
docker-image: ${{ needs.select-env.outputs.image }}
38-
pytest-marker: "(kernels or models)"
38+
pytest-marker: kernels
39+
40+
snitch-models:
41+
needs: select-env
42+
uses: ./.github/workflows/_runner-snitch.yml
43+
with:
44+
runner: ${{ needs.select-env.outputs.runner }}
45+
docker-image: ${{ needs.select-env.outputs.image }}
46+
pytest-marker: models

Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,7 @@ def _remove_global_output_reshape_fun(graph: gs.Graph, match: Match, name: str):
488488
node = next(iter((match.nodes_map.values())))
489489

490490
isGlobalOutput = len(node.outputs[0].outputs) == 0
491-
# Don't delete if the input is also a global input (i.e., single-node graph)
492-
isGlobalInput = node.inputs[0] in graph.inputs
493-
if isGlobalOutput and not isGlobalInput:
491+
if isGlobalOutput:
494492
graph.deleteNode(node)
495493

496494
return graph

Deeploy/Targets/Snitch/Parsers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import onnx_graphsurgeon as gs
99

10-
from Deeploy.DeeployTypes import NetworkContext, NodeParser
10+
from Deeploy.DeeployTypes import NetworkContext
1111
from Deeploy.Targets.Generic.Parsers import AddParser, DivParser, GEMMParser, MulParser, RQGEMMParser, \
1212
iHardswishParser, iRMSNormParser
1313

Deeploy/Targets/Snitch/Platform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from Deeploy.AbstractDataTypes import Pointer, PointerClass, VoidType
1010
from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NodeMapper, NodeTemplate, \
1111
StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer
12-
from Deeploy.Targets.Generic.Bindings import BasicConcatBindings, BasicLayerNormBindings, BasicPad1DBindings, \
13-
BasicPad2DBindings, BasicReshapeBindings, BasicRQIntegerDivBinding
12+
from Deeploy.Targets.Generic.Bindings import BasicLayerNormBindings, BasicPad1DBindings, BasicPad2DBindings, \
13+
BasicReshapeBindings, BasicRQIntegerDivBinding
1414
from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, DivLayer, GatherLayer, GEMMLayer, HardSwishLayer, \
1515
LayerNormLayer, MatMulLayer, MulLayer, PadLayer, ReshapeLayer, RMSNormLayer, RQGEMMLayer, RQIntegerDivLayer, \
1616
SoftmaxLayer, TransposeLayer, iNoNormLayer

Deeploy/Targets/Snitch/TileConstraints/FloatMulTileConstraint.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

Deeploy/Targets/Snitch/TileConstraints/FloatDivTileConstraint.py renamed to Deeploy/Targets/Snitch/TileConstraints/FloatScalarBOPTileConstraint.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
VariableReplacementScheme
1717

1818

19-
class FloatDivTileConstraint(TileConstraint):
20-
"""Tile constraint for FP32 Div: supports scalar and element-wise cases."""
19+
class FloatScalarBOPTileConstraint(TileConstraint):
20+
"""Tile constraint for binary operators with scalar broadcasting support.
21+
22+
Extends BOPTileConstraint with scalar handling: when one input has size 1,
23+
it is loaded in full (not tiled) while the other input and output are tiled together.
24+
Used by FP32 Div and Mul operators.
25+
"""
2126

2227
dataIn1Name = "A"
2328
dataIn2Name = "B"

Deeploy/Targets/Snitch/TileConstraints/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from . import *
6-
from .FloatDivTileConstraint import *
7-
from .FloatMulTileConstraint import *
6+
from .FloatScalarBOPTileConstraint import *
7+
from .GemmTileConstraint import *
88
from .iNoNormTileConstraint import *
99
from .iSoftmaxTileConstraint import *
10+
from .RqGemmTileConstraint import *

Deeploy/Targets/Snitch/Tiler.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@
1414
SnitchGatherBindings, SnitchGemmBindings, SnitchHardSwishBindings, SnitchiNoNormBindings, SnitchiSoftmaxBindings, \
1515
SnitchMatMulBindings, SnitchMulBindings, SnitchReshapeBindings, SnitchRMSNormBindings, SnitchRQAddBindings, \
1616
SnitchRqGemmBindings, SnitchTransposeBindings
17-
from Deeploy.Targets.Snitch.TileConstraints import iNoNormTileConstraint, iSoftmaxTileConstraint
18-
from Deeploy.Targets.Snitch.TileConstraints.FloatDivTileConstraint import FloatDivTileConstraint
19-
from Deeploy.Targets.Snitch.TileConstraints.FloatMulTileConstraint import FloatMulTileConstraint
20-
from Deeploy.Targets.Snitch.TileConstraints.GemmTileConstraint import GemmTileConstraint
21-
from Deeploy.Targets.Snitch.TileConstraints.RqGemmTileConstraint import RqGemmTileConstraint
17+
from Deeploy.Targets.Snitch.TileConstraints import FloatScalarBOPTileConstraint, GemmTileConstraint, \
18+
iNoNormTileConstraint, iSoftmaxTileConstraint, RqGemmTileConstraint
2219
from Deeploy.TilingExtension.TilerExtension import TilingReadyNodeBindings
2320

2421
SnitchiSoftmaxTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchiSoftmaxBindings,
@@ -42,10 +39,10 @@
4239
tileConstraint = iHardswishTileConstraint())
4340

4441
SnitchDivTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchDivBindings,
45-
tileConstraint = FloatDivTileConstraint())
42+
tileConstraint = FloatScalarBOPTileConstraint())
4643

4744
SnitchMulTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchMulBindings,
48-
tileConstraint = FloatMulTileConstraint())
45+
tileConstraint = FloatScalarBOPTileConstraint())
4946

5047
SnitchMatMulTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchMatMulBindings,
5148
tileConstraint = MatMulTileConstraint())

0 commit comments

Comments
 (0)