@@ -535,6 +535,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
535535 assert gm .graph .nodes , "Graph is empty"
536536 layer_subgraphs = []
537537 linear_nodes = list (filtered_nodes (gm .graph .nodes , is_any_lin_op ))
538+
539+ # find the embedding size of this model. Extract it from the input of the first linear node.
540+ embd = get_weight_shape (linear_nodes [0 ], dim = - 1 )
538541 unprocessed_linear_nodes = set (linear_nodes )
539542 assert len (linear_nodes ) > 0 , "Could not find any linear nodes in the graph"
540543
@@ -546,7 +549,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
546549 # opening is the list of linear nodes
547550 # layer_subgraph is the list of nodes between the opening and closing linear nodes
548551 # closing is the last linear node in the layer
549- layer_subgraph = get_layer_after_linear_node (linear_nodes , terminating_indices )
552+ layer_subgraph = get_layer_after_linear_node (linear_nodes , terminating_indices , embd = embd )
550553 if layer_subgraph .opening_nodes is not None and len (layer_subgraph .opening_nodes ) > 0 :
551554 unprocessed_linear_nodes -= (
552555 set (layer_subgraph .opening_nodes )
@@ -812,6 +815,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
812815def get_layer_after_linear_node (
813816 linear_nodes : List [Node ],
814817 terminating_indices : List [int ],
818+ embd : Optional [int ] = None ,
815819 match_on_shapes : bool = True ,
816820 enforce_strict_linear_history : bool = True ,
817821) -> LayerSubgraph :
@@ -886,8 +890,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
886890 layer_type = LayerType .UNKNOWN ,
887891 )
888892 if match_on_shapes :
889- # get embedding size of the opening linear node
890- embd = get_weight_shape (linear_nodes [start_lin_index ], dim = - 1 )
893+ if embd is None :
894+ # get embedding size of the opening linear node
895+ embd = get_weight_shape (linear_nodes [start_lin_index ], dim = - 1 )
891896 # partial init boundary_condition and filter_condition
892897 boundary_condition = partial (boundary_condition , embd = embd , dim = 0 )
893898 filter_condition = partial (filter_condition , embd = embd , dim = 0 )
@@ -940,7 +945,12 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
940945 ssm_nodes = list (filtered_nodes (interior_nodes , is_any_ssm_op ))
941946 attention_nodes = list (filtered_nodes (interior_nodes , is_any_attention_op ))
942947 intermediate_lin_nodes = list (filtered_nodes (interior_nodes , is_any_lin_op ))
943- intermediate_weight_nodes = list (filtered_nodes (interior_nodes , is_weight_node ))
948+ intermediate_weight_nodes = list (
949+ filtered_nodes (
950+ interior_nodes , lambda n : is_weight_node (n ) and
951+ not is_any_lin_op (list (n .users )[0 ])
952+ )
953+ )
944954
945955 layer_type = LayerType .MLP
946956 min_local_shape = 1
0 commit comments