Skip to content

Commit ecaa4cd

Browse files
committed
Add Gemm and RQGemm
1 parent b10c0a5 commit ecaa4cd

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def IntUnpack(value: Any) -> int:
2525

2626
def BoolUnpack(value: Any) -> bool:
2727
value = IntUnpack(value)
28+
assert value in [0, 1], f"Casting to bool only supported from 0, 1. Received {value}"
2829
return bool(value)
2930

3031

@@ -483,6 +484,31 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
483484
],
484485
)
485486

487+
gemmDesc = OperatorDescriptor(
488+
inputDescriptor = IoDesc(["A", "B"], optional = ["C"]),
489+
outputDescriptor = IoDesc("data_out"),
490+
attrDescriptors = [
491+
AttrDesc("alpha", FloatUnpack, default = 1.0),
492+
AttrDesc("beta", FloatUnpack, default = 1.0),
493+
AttrDesc("transA", BoolUnpack, default = False),
494+
AttrDesc("transB", BoolUnpack, default = False),
495+
],
496+
)
497+
498+
rqGemmDesc = RequantizedOperatorDescriptor(
499+
inputDescriptor = IoDesc(["A", "B", "C", "add", "mul"]),
500+
outputDescriptor = IoDesc("data_out"),
501+
attrDescriptors = [
502+
AttrDesc("alpha", FloatUnpack, default = 1.0),
503+
AttrDesc("beta", FloatUnpack, default = 1.0),
504+
AttrDesc("transA", BoolUnpack, default = False),
505+
AttrDesc("transB", BoolUnpack, default = False),
506+
# RequantizedShift attrs
507+
AttrDesc("n_levels", IntUnpack),
508+
AttrDesc("signed", BoolUnpack),
509+
AttrDesc("div", IntUnpack),
510+
])
511+
486512
defaultOperatorDescriptors: Dict[str, OperatorDescriptor] = {
487513
"Add": addDesc,
488514
"Concat": concatDesc,

0 commit comments

Comments
 (0)