@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params
 
 
-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key
 
 
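Note on the new return contract: `extract_weight_name` now signals a weight-less node with `False` instead of raising an `IndexError`. A minimal caller-side sketch (the stub below mimics the diff; it is not the repo's helper):

```python
# Hedged sketch: extract_weight_name_stub mirrors the diff's contract of
# returning False when a node carries no weights, and the first weight's
# key otherwise. Names and data are illustrative.
from typing import Union

def extract_weight_name_stub(weights: list[str]) -> Union[str, bool]:
    if len(weights) == 0:
        return False  # weight-less node: signal with a sentinel, don't raise
    return weights[0]

for weights in (["model.layers.0.fc1.weight"], []):
    name = extract_weight_name_stub(weights)
    if name is False:
        print("node has no weights; skipping")
    else:
        print(f"found weight: {name}")
```

`Optional[str]` with a `None` sentinel would arguably be more idiomatic than `Union[str, bool]`, but the sketch follows the contract as written.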
@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)
 
 
+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
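`is_weight_node` leans on the repo's `has_shape`/`shape` metadata helpers. A self-contained approximation on a toy fx graph, resolving the `get_attr` target against the module's parameters instead (`is_weight_node_sketch` is hypothetical):

```python
# Tracing x @ self.weight produces a get_attr node for the parameter; the
# sketch flags get_attr nodes whose tensors have nonzero rank, which is the
# same idea as is_weight_node without the repo's shape-metadata helpers.
import torch
from torch import nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(8, 4))

    def forward(self, x):
        return x @ self.weight

gm = torch.fx.symbolic_trace(Tiny())

def is_weight_node_sketch(gm: torch.fx.GraphModule, node: torch.fx.Node) -> bool:
    if node.op != "get_attr" or not node.target:
        return False
    param = dict(gm.named_parameters()).get(str(node.target))
    return param is not None and param.dim() > 0  # excludes rank-0 scalars

print([n.name for n in gm.graph.nodes if is_weight_node_sketch(gm, n)])  # ['weight']
```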
@@ -531,6 +546,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
     assert gm.graph.nodes, "Graph is empty"
     layer_subgraphs = []
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
+
+    # Find the embedding size of this model; extract it from the input dim of the first linear node.
+    embd = get_weight_shape(linear_nodes[0], dim=-1)
     unprocessed_linear_nodes = set(linear_nodes)
     assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"
 
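Why `dim=-1` yields the embedding size: assuming `get_weight_shape` reads the weight tensor's shape, `nn.Linear` stores its weight as `(out_features, in_features)`, so the last dim of the first projection's weight is the model's hidden width. A two-line check:

```python
# nn.Linear(in_features, out_features) keeps weight as (out, in), so
# weight.shape[-1] recovers the hidden/embedding size of the model.
import torch
from torch import nn

first_linear = nn.Linear(in_features=768, out_features=3072)
embd = first_linear.weight.shape[-1]
print(embd)  # 768
```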
@@ -542,7 +560,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
         # opening is the list of linear nodes
         # layer_subgraph is the list of nodes between the opening and closing linear nodes
         # closing is the last linear node in the layer
-        layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
+        layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)
         if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
             unprocessed_linear_nodes -= (
                 set(layer_subgraph.opening_nodes)
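The surrounding loop (context lines above) claims each discovered layer's linear nodes via set subtraction. A standalone sketch of that bookkeeping, with illustrative names in place of fx nodes:

```python
# Each discovered layer removes its linear nodes from the unprocessed set;
# the outer loop runs until every linear node has been claimed.
unprocessed = {"q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"}
discovered_layers = [
    {"q_proj", "k_proj", "v_proj", "o_proj"},  # an attention layer
    {"fc1", "fc2"},                            # an MLP layer
]
for layer in discovered_layers:
    unprocessed -= layer
assert not unprocessed  # all linear nodes accounted for
```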
@@ -808,6 +826,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int
 def get_layer_after_linear_node(
     linear_nodes: List[Node],
     terminating_indices: List[int],
+    embd: Optional[int] = None,
     match_on_shapes: bool = True,
     enforce_strict_linear_history: bool = True,
 ) -> LayerSubgraph:
@@ -882,8 +901,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
                 layer_type=LayerType.UNKNOWN,
             )
         if match_on_shapes:
-            # get embedding size of the opening linear node
-            embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
+            if embd is None:
+                # get embedding size of the opening linear node
+                embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
             # partial init boundary_condition and filter_condition
             boundary_condition = partial(boundary_condition, embd=embd, dim=0)
             filter_condition = partial(filter_condition, embd=embd, dim=0)
@@ -892,6 +912,18 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
             sources=[linear_nodes[start_lin_index]], boundary_condition=boundary_condition
         )
         lin_nodes_in_subgraph = list(filtered_nodes(forward_subgraph, filter_condition))
+        if len(lin_nodes_in_subgraph) > 1:
+            # This likely means we went past the boundary of the layer.
+            # It can happen, e.g., with MoLE (latent MoE): the subgraph spanned
+            # over the closing latent fc2 projection "spills" into consecutive layers.
+            # In that case, wrap this single linear node in LayerType.UNKNOWN and return.
+            terminating_indices.append(start_lin_index)
+            return LayerSubgraph(
+                opening_nodes=[linear_nodes[start_lin_index]],
+                subgraph_nodes=[],
+                terminating_node=linear_nodes[start_lin_index],
+                layer_type=LayerType.UNKNOWN,
+            )
         start_lin_index += 1
     start_lin_index -= 1
     terminating_linear_node = lin_nodes_in_subgraph[0]
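The context lines above freeze `embd`/`dim` into the boundary and filter predicates with `functools.partial`. A minimal standalone illustration of the pattern, under one plausible reading where a layer-closing linear projects back to the hidden size (`filter_condition_sketch` is hypothetical):

```python
from functools import partial

def filter_condition_sketch(weight_shape: tuple, embd: int, dim: int = 0) -> bool:
    # a linear node "closes" the layer when the checked dim matches the model width
    return weight_shape[dim] == embd

closes_layer = partial(filter_condition_sketch, embd=768, dim=0)
print(closes_layer((768, 3072)))  # True: projects back to the hidden size
print(closes_layer((3072, 768)))  # False: an expanding projection
```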
@@ -924,25 +956,39 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
     ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
     attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
     intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
+    intermediate_weight_nodes = list(
+        filtered_nodes(
+            interior_nodes, lambda n: is_weight_node(n) and not is_any_lin_op(list(n.users)[0])
+        )
+    )
 
     layer_type = LayerType.MLP
     min_local_shape = 1
     if len(ssm_nodes) > 0:
-        assert len(ssm_nodes) == 1, "SSM layer must have exactly one SSM node"
-        layer_type = LayerType.SSM
-        # determine head size
-        min_local_shape = shape(ssm_nodes[0])[-1]
-    if len(attention_nodes) > 0:
-        assert len(attention_nodes) == 1, "Attention layer must have exactly one attention node"
-        layer_type = LayerType.ATTENTION
-        # determine head size
-        min_local_shape = shape(attention_nodes[0])[-1]
-    if len(intermediate_lin_nodes) > 0:
-        assert len(intermediate_lin_nodes) == 2, (
-            "MLA layer must have exactly two intermediate linear nodes"
-        )
-        assert len(attention_nodes) == 1, "MLA layer must have exactly one attention node"
-        layer_type = LayerType.MLA
+        if len(ssm_nodes) == 1:
+            layer_type = LayerType.SSM
+            # determine head size
+            min_local_shape = shape(ssm_nodes[0])[-1]
+        else:
+            layer_type = LayerType.UNKNOWN
+    if len(attention_nodes) > 0 and layer_type != LayerType.UNKNOWN:
+        if len(attention_nodes) == 1:
+            layer_type = LayerType.ATTENTION
+            # determine head size
+            min_local_shape = shape(attention_nodes[0])[-1]
+        else:
+            layer_type = LayerType.UNKNOWN
+    if len(intermediate_lin_nodes) > 0 and layer_type != LayerType.UNKNOWN:
+        if len(intermediate_lin_nodes) == 2 and len(attention_nodes) == 1:
+            layer_type = LayerType.MLA
+        else:
+            layer_type = LayerType.UNKNOWN
+    # Only SSM or MLA layers may have weight nodes among their interior nodes.
+    # TODO: Minimax does have an RMSNorm inside attention; we need to
+    # support it in the future.
+    if len(intermediate_weight_nodes) > 0:
+        if layer_type not in [LayerType.SSM, LayerType.MLA]:
+            layer_type = LayerType.UNKNOWN
 
     layer_subgraph = LayerSubgraph(
         opening_nodes=opening_linear_nodes,
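The reworked ladder replaces the hard asserts with a graceful `LayerType.UNKNOWN` fallback. A self-contained condensation of the new control flow, with the filtered node lists reduced to plain counts (names are illustrative, not the repo's enum):

```python
from enum import Enum, auto

class LayerType(Enum):
    MLP = auto()
    SSM = auto()
    ATTENTION = auto()
    MLA = auto()
    UNKNOWN = auto()

def classify(n_ssm: int, n_attn: int, n_lin: int, n_weight: int) -> LayerType:
    layer_type = LayerType.MLP
    if n_ssm > 0:
        layer_type = LayerType.SSM if n_ssm == 1 else LayerType.UNKNOWN
    if n_attn > 0 and layer_type is not LayerType.UNKNOWN:
        layer_type = LayerType.ATTENTION if n_attn == 1 else LayerType.UNKNOWN
    if n_lin > 0 and layer_type is not LayerType.UNKNOWN:
        layer_type = LayerType.MLA if (n_lin == 2 and n_attn == 1) else LayerType.UNKNOWN
    # interior weight nodes (e.g. norms) are only expected inside SSM/MLA layers
    if n_weight > 0 and layer_type not in (LayerType.SSM, LayerType.MLA):
        layer_type = LayerType.UNKNOWN
    return layer_type

print(classify(0, 1, 0, 0))  # LayerType.ATTENTION
print(classify(0, 1, 2, 0))  # LayerType.MLA
print(classify(2, 0, 0, 0))  # LayerType.UNKNOWN
```

Each ill-formed combination now falls through to `UNKNOWN` instead of crashing the whole graph pass, which matches the spill-handling change above.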