Skip to content

Commit 19992ec

Browse files
committed
Add CLCA
1 parent d256432 commit 19992ec

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,31 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
542542
AttrDesc("heads", IntUnpack),
543543
])
544544

545+
clcaDesc = OperatorDescriptor(
546+
inputDescriptor = IoDesc([
547+
"q", "k", "wq_weight", "wq_bias", "wk_weight", "wk_bias", "wo_weight", "wo_bias", "wq_requant_mul",
548+
"wq_requant_add", "wq_requant_div", "wk_requant_mul", "wk_requant_add", "wk_requant_div", "wv_requant_mul",
549+
"wv_requant_add", "wv_requant_div", "kdiv_requant_mul", "kdiv_requant_add", "kdiv_requant_div",
550+
"preattn_requant_mul", "preattn_requant_add", "preattn_requant_div", "postattn_requant_mul",
551+
"postattn_requant_add", "postattn_requant_div", "wo_requant_mul", "wo_requant_add", "wo_requant_div"
552+
]),
553+
outputDescriptor = IoDesc("data_out"),
554+
attrDescriptors = [
555+
AttrDesc("Delta", IntUnpack),
556+
AttrDesc("eps", IntUnpack),
557+
AttrDesc("eta", IntUnpack),
558+
AttrDesc("act_type", IntUnpack),
559+
AttrDesc("n_levels", IntUnpack),
560+
AttrDesc("dim", IntUnpack),
561+
AttrDesc("dim_head", IntUnpack),
562+
AttrDesc("out_dim", IntUnpack),
563+
AttrDesc("heads", IntUnpack),
564+
],
565+
)
566+
545567
defaultOperatorDescriptors: Dict[str, OperatorDescriptor] = {
546568
"Add": addDesc,
569+
"CLCA": clcaDesc,
547570
"Concat": concatDesc,
548571
"Conv": convDesc,
549572
"DebugPrint": debugPrintDesc,

Deeploy/Targets/Generic/Parsers.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,15 +1483,7 @@ def parseNode(self, node: gs.Node) -> (bool):
14831483
])
14841484

14851485
if ret:
1486-
self.operatorRepresentation['Delta'] = int(node.attrs['Delta'])
1487-
self.operatorRepresentation['eps'] = int(node.attrs['eps'])
1488-
self.operatorRepresentation['eta'] = int(node.attrs['eta'])
1489-
self.operatorRepresentation['act_type'] = int(node.attrs['act_type'])
1490-
self.operatorRepresentation['n_levels'] = int(node.attrs['n_levels'].values)
1491-
self.operatorRepresentation['dim'] = int(node.attrs['dim'].values)
1492-
self.operatorRepresentation['dim_head'] = int(node.attrs['dim_head'].values)
1493-
self.operatorRepresentation['out_dim'] = int(node.attrs['out_dim'].values)
1494-
self.operatorRepresentation['heads'] = int(node.attrs['heads'].values)
1486+
self.operatorRepresentation.update(node.attrs)
14951487

14961488
return ret
14971489

0 commit comments

Comments
 (0)