@@ -546,6 +546,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
546546 assert gm .graph .nodes , "Graph is empty"
547547 layer_subgraphs = []
548548 linear_nodes = list (filtered_nodes (gm .graph .nodes , is_any_lin_op ))
549+
550+ # find the embedding size of this model. Extract it from the input of the first linear node.
551+ embd = get_weight_shape (linear_nodes [0 ], dim = - 1 )
549552 unprocessed_linear_nodes = set (linear_nodes )
550553 assert len (linear_nodes ) > 0 , "Could not find any linear nodes in the graph"
551554
@@ -557,7 +560,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
557560 # opening is the list of linear nodes
558561 # layer_subgraph is the list of nodes between the opening and closing linear nodes
559562 # closing is the last linear node in the layer
560- layer_subgraph = get_layer_after_linear_node (linear_nodes , terminating_indices )
563+ layer_subgraph = get_layer_after_linear_node (linear_nodes , terminating_indices , embd = embd )
561564 if layer_subgraph .opening_nodes is not None and len (layer_subgraph .opening_nodes ) > 0 :
562565 unprocessed_linear_nodes -= (
563566 set (layer_subgraph .opening_nodes )
@@ -823,6 +826,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
823826def get_layer_after_linear_node (
824827 linear_nodes : List [Node ],
825828 terminating_indices : List [int ],
829+ embd : Optional [int ] = None ,
826830 match_on_shapes : bool = True ,
827831 enforce_strict_linear_history : bool = True ,
828832) -> LayerSubgraph :
@@ -897,8 +901,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
897901 layer_type = LayerType .UNKNOWN ,
898902 )
899903 if match_on_shapes :
900- # get embedding size of the opening linear node
901- embd = get_weight_shape (linear_nodes [start_lin_index ], dim = - 1 )
904+ if embd is None :
905+ # get embedding size of the opening linear node
906+ embd = get_weight_shape (linear_nodes [start_lin_index ], dim = - 1 )
902907 # partial init boundary_condition and filter_condition
903908 boundary_condition = partial (boundary_condition , embd = embd , dim = 0 )
904909 filter_condition = partial (filter_condition , embd = embd , dim = 0 )
@@ -951,7 +956,12 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
951956 ssm_nodes = list (filtered_nodes (interior_nodes , is_any_ssm_op ))
952957 attention_nodes = list (filtered_nodes (interior_nodes , is_any_attention_op ))
953958 intermediate_lin_nodes = list (filtered_nodes (interior_nodes , is_any_lin_op ))
954- intermediate_weight_nodes = list (filtered_nodes (interior_nodes , is_weight_node ))
959+ intermediate_weight_nodes = list (
960+ filtered_nodes (
961+ interior_nodes , lambda n : is_weight_node (n ) and
962+ not is_any_lin_op (list (n .users )[0 ])
963+ )
964+ )
955965
956966 layer_type = LayerType .MLP
957967 min_local_shape = 1
0 commit comments