Skip to content

Commit 9cea130

Browse files
author
yarden-sony
committed
box decode tpc
1 parent 0963d1d commit 9cea130

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

model_compression_toolkit/target_platform_capabilities/schema/v1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class OperatorSetNames(str, Enum):
7474
TOPK = "TopK"
7575
FAKE_QUANT = "FakeQuant"
7676
COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
77+
BOX_DECODE = "BoxDecode"
7778
ZERO_PADDING2D = "ZeroPadding2D"
7879
CAST = "Cast"
7980
RESIZE = "Resize"

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
3333
AttachTpcToFramework
3434
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
35-
from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices, multiclass_nms_with_indices
35+
from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices, multiclass_nms_with_indices, \
36+
FasterRCNNBoxDecode, multiclass_nms
3637

3738

3839
class AttachTpcToPytorch(AttachTpcToFramework):
@@ -98,7 +99,9 @@ def __init__(self):
9899
OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
99100
Eq('p', 2) | Eq('p', None))],
100101
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
101-
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [multiclass_nms_with_indices, MulticlassNMS, MulticlassNMSWithIndices] # no such operator in pytorch
102+
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [multiclass_nms, multiclass_nms_with_indices, MulticlassNMS,
103+
MulticlassNMSWithIndices], # no such operator in pytorch
104+
OperatorSetNames.BOX_DECODE: [FasterRCNNBoxDecode]
102105
}
103106

104107
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),

0 commit comments

Comments
 (0)