@@ -25,6 +25,7 @@ def IntUnpack(value: Any) -> int:
2525
2626def BoolUnpack (value : Any ) -> bool :
2727 value = IntUnpack (value )
28+ assert value in [0 , 1 ], f"Casting to bool only supported from 0, 1. Received { value } "
2829 return bool (value )
2930
3031
@@ -483,6 +484,31 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
483484 ],
484485)
485486
487+ gemmDesc = OperatorDescriptor (
488+ inputDescriptor = IoDesc (["A" , "B" ], optional = ["C" ]),
489+ outputDescriptor = IoDesc ("data_out" ),
490+ attrDescriptors = [
491+ AttrDesc ("alpha" , FloatUnpack , default = 1.0 ),
492+ AttrDesc ("beta" , FloatUnpack , default = 1.0 ),
493+ AttrDesc ("transA" , BoolUnpack , default = False ),
494+ AttrDesc ("transB" , BoolUnpack , default = False ),
495+ ],
496+ )
497+
498+ rqGemmDesc = RequantizedOperatorDescriptor (
499+ inputDescriptor = IoDesc (["A" , "B" , "C" , "add" , "mul" ]),
500+ outputDescriptor = IoDesc ("data_out" ),
501+ attrDescriptors = [
502+ AttrDesc ("alpha" , FloatUnpack , default = 1.0 ),
503+ AttrDesc ("beta" , FloatUnpack , default = 1.0 ),
504+ AttrDesc ("transA" , BoolUnpack , default = False ),
505+ AttrDesc ("transB" , BoolUnpack , default = False ),
506+ # RequantizedShift attrs
507+ AttrDesc ("n_levels" , IntUnpack ),
508+ AttrDesc ("signed" , BoolUnpack ),
509+ AttrDesc ("div" , IntUnpack ),
510+ ])
511+
486512defaultOperatorDescriptors : Dict [str , OperatorDescriptor ] = {
487513 "Add" : addDesc ,
488514 "Concat" : concatDesc ,
0 commit comments