@@ -542,8 +542,31 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
542542 AttrDesc ("heads" , IntUnpack ),
543543 ])
544544
545+ clcaDesc = OperatorDescriptor (
546+ inputDescriptor = IoDesc ([
547+ "q" , "k" , "wq_weight" , "wq_bias" , "wk_weight" , "wk_bias" , "wo_weight" , "wo_bias" , "wq_requant_mul" ,
548+ "wq_requant_add" , "wq_requant_div" , "wk_requant_mul" , "wk_requant_add" , "wk_requant_div" , "wv_requant_mul" ,
549+ "wv_requant_add" , "wv_requant_div" , "kdiv_requant_mul" , "kdiv_requant_add" , "kdiv_requant_div" ,
550+ "preattn_requant_mul" , "preattn_requant_add" , "preattn_requant_div" , "postattn_requant_mul" ,
551+ "postattn_requant_add" , "postattn_requant_div" , "wo_requant_mul" , "wo_requant_add" , "wo_requant_div"
552+ ]),
553+ outputDescriptor = IoDesc ("data_out" ),
554+ attrDescriptors = [
555+ AttrDesc ("Delta" , IntUnpack ),
556+ AttrDesc ("eps" , IntUnpack ),
557+ AttrDesc ("eta" , IntUnpack ),
558+ AttrDesc ("act_type" , IntUnpack ),
559+ AttrDesc ("n_levels" , IntUnpack ),
560+ AttrDesc ("dim" , IntUnpack ),
561+ AttrDesc ("dim_head" , IntUnpack ),
562+ AttrDesc ("out_dim" , IntUnpack ),
563+ AttrDesc ("heads" , IntUnpack ),
564+ ],
565+ )
566+
545567defaultOperatorDescriptors : Dict [str , OperatorDescriptor ] = {
546568 "Add" : addDesc ,
569+ "CLCA" : clcaDesc ,
547570 "Concat" : concatDesc ,
548571 "Conv" : convDesc ,
549572 "DebugPrint" : debugPrintDesc ,
0 commit comments