Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions graph_net/torch/typical_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Typical Sequence Extractor
Identify repeated subgraph patterns from extracted FX Graph and save them categorized.
"""
import argparse
import os
import sys
import json
import hashlib
import ast
import copy
from pathlib import Path
from typing import List, Dict, Tuple, Set, Any
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
from graph_net.torch.rp_expr.rp_expr_util import (
MakeNestedIndexRangeFromLetsListTokenRpExpr,
)


def _get_leaf_model_pathes(src_model_path: str):
# Traverse all submodule (features.0, classifier.6) in src_model_path
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有实现全的代码,一律 return TODO()



def _get_fx_graph(leaf_model_path: str):
# Load the GraphModule and extract its fx.Graph
return


def _get_fx_node(fx_graph: str):
# Traverse fx_graph.nodes to obtain all Nodes (splaceholder, call_function, output)
return


def encode_node_to_stmt_token_id(node: Any):
    """Encode ``node`` as a statement token id representing a pattern.

    The encoding is not implemented yet.

    Args:
        node: A graph node; presumably a ``torch.fx.Node`` (the previous
            ``str`` annotation did not match that use) -- TODO confirm.

    Raises:
        NotImplementedError: Always, until the encoding is written.
    """
    # Fail loudly instead of silently encoding every node as None.
    raise NotImplementedError("encode_node_to_stmt_token_id is not implemented yet")


def SequenceUnittestsGenerator(
    program_id: str, seq_stmts: List[str], dist_model_path: str
):
    """Generate unittests for one extracted statement sequence.

    The generation logic is not implemented yet.

    Args:
        program_id: Identifier of the program the sequence was taken from.
        seq_stmts: The statements that make up the sequence.
        dist_model_path: Destination path for the generated unittests.

    Raises:
        NotImplementedError: Always, until the generator is written.
    """
    # Fail loudly: a silent no-op would make the pipeline look successful
    # while producing no output at all.
    raise NotImplementedError("SequenceUnittestsGenerator is not implemented yet")


def extract_typical_sequences(src_model_path: str, dist_model_path: str, dynamic=True):
    """Extract repeated statement sequences and generate unittests for them.

    Pipeline: collect the fx graphs under ``src_model_path``, encode every
    node as a statement token id, let ``RpExprParser`` find repeated
    patterns, map those patterns back to index ranges, keep the ranges
    accepted by the size filter, and hand each resulting sequence to
    ``SequenceUnittestsGenerator``.

    Args:
        src_model_path: Root path of the extracted model sample(s).
        dist_model_path: Destination path passed on to the unittest generator.
        dynamic: Currently unused by this function.

    Returns:
        None.
    """
    # src_model_path may hold more than one computation graph (for example a
    # whole samples directory), so every leaf model is loaded separately.
    leaf_model_paths = list(_get_leaf_model_pathes(src_model_path))
    fx_graphs = [_get_fx_graph(leaf_model_path) for leaf_model_path in leaf_model_paths]

    # Convert each fx.Graph into a sequence of stmt token_ids
    stmt_token_ids = [
        [encode_node_to_stmt_token_id(node) for node in _get_fx_node(fx_graph)]
        for fx_graph in fx_graphs
    ]

    # Extract typical subgraph patterns by RpExprParser
    parser = RpExprParser(window_size=64)
    lets_list_rp_expr, _token_id2primitive_id = parser(stmt_token_ids)

    # Map the pattern back to the original FX Node range
    trees = MakeNestedIndexRangeFromLetsListTokenRpExpr(lets_list_rp_expr)

    # Filter by length
    ranges_list = [
        tree.FilterSubTreeRangeBySize(min_length=2, max_length=33) for tree in trees
    ]

    # Build the (program_id, seq_stmts) pairs as a concrete list.
    # NOTE(review): the leaf model path is used as program_id, and trees are
    # assumed to line up one-to-one with the input programs -- confirm both.
    program_seq_stmts_list = [
        (program_id, origin_seq_stmts[start:end])
        for program_id, origin_seq_stmts, ranges in zip(
            leaf_model_paths, stmt_token_ids, ranges_list
        )
        for start, end in ranges
    ]

    # Generate unittests for each sequence
    # Each folder: subgraph_<hash_id>_<count_id>.py
    for program_id, seq_stmts in program_seq_stmts_list:
        SequenceUnittestsGenerator(program_id, seq_stmts, dist_model_path)


if __name__ == "__main__":
    # Command-line entry point: resolve the source and destination paths.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--src_model_path", type=str, default="samples/torchvision/alexnet"
    )
    parser.add_argument("--dist_model_path", type=str)
    args = parser.parse_args()

    # When no dist_model_path is given, derive it from src_model_path.
    # Example: src_model_path:  GraphNet/samples/torchvision/alexnet
    #          dist_model_path: GraphNet/subgraphs/torchvision/alexnet
    if not args.dist_model_path:
        components = args.src_model_path.split(os.sep)
        try:
            samples_index = components.index("samples")
        except ValueError:
            print(
                "Warning: 'samples' not found in src_model_path. Using default structure for dist_model_path."
            )
            components.insert(2, "subgraphs")
        else:
            components[samples_index] = "subgraphs"
        args.dist_model_path = os.sep.join(components)

        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when the derived path actually has one.
        parent_dir = os.path.dirname(args.dist_model_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # extract_typical_sequences(args.src_model_path, args.dist_model_path)

    print("Source model path:", args.src_model_path)
    print("Distribution model path:", args.dist_model_path)
Loading