Skip to content

Commit 7c075bd

Browse files
committed
feat(agent): Implement automatic model extraction agent for Hackathon No.10
1 parent c24ad82 commit 7c075bd

File tree

10 files changed

+492
-0
lines changed

10 files changed

+492
-0
lines changed

demo_agent.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
import sys
3+
4+
# Ensure we can import graph_net
5+
sys.path.append(os.getcwd())
6+
7+
from graph_net.agent import GraphNetAgent
8+
9+
def main():
    """Run the extraction agent once against a small Hugging Face model.

    The workspace directory is anchored to the directory containing this
    script (not the current working directory), and the outcome of the run
    is reported on stdout.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    workspace = os.path.join(script_dir, "agent_workspace")
    print(f"Using workspace: {workspace}")

    agent = GraphNetAgent(workspace=workspace)

    # A tiny checkpoint keeps the demo quick.
    test_model = "prajjwal1/bert-tiny"
    print(f"Processing model: {test_model}")

    if not agent.process_model(test_model):
        print("\n[FAILURE] Agent failed to process the model.")
        return

    print("\n[SUCCESS] Agent successfully processed the model!")
    print(f"Check results in: {workspace}/downloads/{test_model.replace('/', '_')}/extracted_sample")


if __name__ == "__main__":
    main()

docs/agent_design.md

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# GraphNet 自动样本抽取 Agent 设计文档
2+
3+
## 1. 任务背景
4+
为了丰富 GraphNet 的样本库,我们需要从 Hugging Face (HF) 上自动下载模型,并将其转换为 GraphNet 可用的子图样本。目前这一过程需要人工编写 `run_model.py` 代码,效率较低。本 Agent 旨在自动化这一流程,实现从“HF 模型链接”到“GraphNet 样本提交”的端到端自动化。
5+
6+
## 2. 核心架构
7+
Agent 采用模块化设计,主要包含以下组件:
8+
9+
### 2.1 架构图
10+
```mermaid
11+
graph TD
12+
User[用户输入: HF Model ID] --> Manager[Agent Manager]
13+
Manager --> Fetcher[Model Fetcher]
14+
Fetcher -- 下载模型 --> Local[本地模型文件]
15+
Manager --> Analyzer[Model Analyzer]
16+
    Analyzer -- 分析 config.json --> Meta["模型元数据 (Input Shape/Dtype)"]
17+
Manager --> Coder[Code Generator]
18+
Meta --> Coder
19+
Coder -- 生成代码 --> Script[run_model.py]
20+
Manager --> Extractor[Graph Extractor]
21+
Script --> Extractor
22+
Extractor -- 运行 & 抽图 --> Sample[GraphNet Sample]
23+
Manager --> Verifier[Sample Verifier]
24+
Sample --> Verifier
25+
Verifier -- 验证通过 --> Git[Git Submitter]
26+
```
27+
28+
### 2.2 模块说明
29+
30+
#### 1. Model Fetcher (`agent.fetcher`)
31+
- **功能**: 调用 `huggingface_hub` 下载模型快照。
32+
- **输入**: `model_id` (e.g., `bert-base-uncased`)
33+
- **输出**: 本地路径。
34+
35+
#### 2. Model Analyzer (`agent.analyzer`)
36+
- **功能**: 解析模型目录下的 `config.json` 和 `README.md`。
37+
- **目标**: 推断模型的 `input_shape` 和 `input_dtype`。例如 BERT 通常需要 `input_ids` [batch, seq_len] (int64)。
38+
39+
#### 3. Code Generator (`agent.coder`)
40+
- **功能**: 生成 `run_model.py`
41+
- **策略**:
42+
- **Template Mode**: 针对常见架构(如 Bert, ResNet, GPT)使用预定义模板。
43+
- **LLM Mode (可选)**: 调用外部 LLM API 生成代码(预留接口)。
44+
45+
#### 4. Graph Extractor (`agent.extractor`)
46+
- **功能**: 在子进程中运行生成的 `run_model.py`
47+
- **依赖**: 复用 `graph_net.torch.run_model` 或直接调用脚本。
48+
49+
#### 5. Sample Verifier (`agent.verifier`)
50+
- **功能**: 检查生成的 `graph_net.json`, `model.py`, `input_meta.py` 是否存在且合法。
51+
52+
## 3. 接口设计
53+
54+
### `GraphNetAgent`
55+
```python
56+
class GraphNetAgent:
57+
def __init__(self, workspace: str, hf_token: str = None):
58+
self.workspace = workspace
59+
self.fetcher = HFFetcher(token=hf_token)
60+
self.analyzer = ConfigAnalyzer()
61+
self.coder = TemplateCoder()
62+
self.extractor = SubprocessExtractor()
63+
self.verifier = BasicVerifier()
64+
65+
def run(self, model_id: str) -> bool:
66+
# 1. Download
67+
model_dir = self.fetcher.download(model_id)
68+
69+
# 2. Analyze
70+
meta_info = self.analyzer.analyze(model_dir)
71+
72+
# 3. Generate Code
73+
code_path = self.coder.generate(model_dir, meta_info)
74+
75+
# 4. Extract
76+
output_dir = self.extractor.extract(code_path)
77+
78+
# 5. Verify
79+
return self.verifier.verify(output_dir)
80+
```
81+
82+
## 4. 目录结构
83+
```text
84+
graph_net/
85+
agent/
86+
__init__.py
87+
core.py # Agent 主逻辑
88+
fetcher.py # 下载模块
89+
analyzer.py # 分析模块
90+
coder/
91+
base.py
92+
template.py # 模板生成
93+
llm.py # LLM 生成 (Interface)
94+
extractor.py # 运行模块
95+
verifier.py # 验证模块
96+
```
97+
98+
## 5. 扩展性计划
99+
- 支持更多的 HF 任务类型(NLP, CV, Audio)。
100+
- 接入 DeepSeek/OpenAI API 提升代码生成成功率。
101+
- 自动化 PR 提交功能。

graph_net/agent/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .core import GraphNetAgent
2+
3+
__all__ = ["GraphNetAgent"]

graph_net/agent/analyzer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
import json
3+
import logging
4+
from typing import Dict, Any
5+
6+
class ConfigAnalyzer:
    """Infer model input specifications from a model directory's ``config.json``."""

    def __init__(self):
        self.logger = logging.getLogger("ConfigAnalyzer")

    def analyze(self, model_dir: str) -> Dict[str, Any]:
        """Analyze ``config.json`` to infer input specifications.

        Args:
            model_dir: Directory that contains the model's ``config.json``.

        Returns:
            A dict with the keys ``architecture``, ``input_shape``,
            ``input_dtype`` and ``task_type``. ``input_names`` is added only
            when the architecture matches one of the known families below.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from ``model_dir``.
        """
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # "architectures" may be absent, null, or an empty list; indexing it
        # blindly would raise, so every such case falls back to "Unknown".
        architectures = config.get("architectures") or ["Unknown"]
        if isinstance(architectures, str):
            # Tolerate a bare string instead of a one-element list.
            architectures = [architectures]
        architecture = str(architectures[0])
        self.logger.info("Detected architecture: %s", architecture)

        # Heuristic defaults: a text model fed token ids.
        meta_info: Dict[str, Any] = {
            "architecture": architecture,
            "input_shape": [1, 128],  # Default batch size 1, seq len 128
            "input_dtype": "int64",
            "task_type": "nlp",
        }

        # Hugging Face class names spell the GPT / Llama families in several
        # casings (e.g. "GPT2LMHeadModel", "LLaMAForCausalLM"), so those two
        # are matched case-insensitively. The other families keep their
        # case-sensitive match so unrelated names are not picked up.
        arch_lower = architecture.lower()

        # Refine based on architecture
        if "Bert" in architecture or "Roberta" in architecture:
            meta_info["input_names"] = ["input_ids", "attention_mask", "token_type_ids"]
        elif "gpt" in arch_lower or "llama" in arch_lower:
            meta_info["input_names"] = ["input_ids", "attention_mask"]
        elif "ResNet" in architecture or "ViT" in architecture:
            meta_info["task_type"] = "cv"
            meta_info["input_shape"] = [1, 3, 224, 224]
            meta_info["input_dtype"] = "float32"
            meta_info["input_names"] = ["pixel_values"]

        return meta_info

graph_net/agent/coder/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .template import TemplateCoder

graph_net/agent/coder/template.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import os
2+
import logging
3+
from typing import Dict, Any
4+
5+
class TemplateCoder:
    """Render a standalone Python script that loads a model and runs extraction."""

    def __init__(self):
        self.logger = logging.getLogger("TemplateCoder")

    def generate(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """Generate a python script to load the model and run extraction.

        Args:
            model_dir: Directory holding the model; the script is written here
                as ``run_extraction.py``.
            meta_info: Input specification; see ``_create_script_content`` for
                the keys that are read.

        Returns:
            Path of the written script.

        Raises:
            ValueError: If ``meta_info["task_type"]` is not supported. Nothing
                is written in that case.
        """
        script_content = self._create_script_content(model_dir, meta_info)

        output_path = os.path.join(model_dir, "run_extraction.py")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(script_content)

        return output_path

    def _create_script_content(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """Build the source text of the extraction script.

        Only ``input_shape`` (default ``[1, 128]``) and ``task_type`` (default
        ``"nlp"``) are read from ``meta_info``. Tensor dtypes are fixed per
        task type: int64 for NLP, float32 for CV.

        Raises:
            ValueError: If the task type is neither ``"nlp"`` nor ``"cv"``.
        """
        input_shape = tuple(meta_info.get("input_shape", [1, 128]))
        task_type = meta_info.get("task_type", "nlp")

        if task_type == "nlp":
            input_lines = [
                "# NLP Inputs",
                f"input_ids = torch.randint(0, 100, {input_shape}, dtype=torch.int64)",
                f"attention_mask = torch.ones({input_shape}, dtype=torch.int64)",
                "inputs = (input_ids, attention_mask)",
            ]
        elif task_type == "cv":
            input_lines = [
                "# CV Inputs",
                f"inputs = (torch.randn({input_shape}, dtype=torch.float32),)",
            ]
        else:
            # Fail here rather than emit a script that would die later with a
            # NameError on the undefined ``inputs`` variable.
            raise ValueError(f"Unsupported task_type: {task_type!r}")

        # The block is spliced into the body of the generated main(), so it
        # must carry one level of indentation.
        input_gen_code = "\n".join("    " + line for line in input_lines)

        output_dir = os.path.join(model_dir, "extracted_sample")

        # Paths are embedded with repr() so the generated file stays valid
        # Python even when a path contains quotes or ends with a backslash
        # (neither can be expressed inside an r"..." literal).
        # SECURITY NOTE: the generated script loads the model with
        # trust_remote_code=True, which executes code shipped with the model
        # repository. Only run it on models you trust.
        return f"""import os
import sys

import torch
from transformers import AutoModel, AutoConfig

# Ensure graph_net is in path
sys.path.append(os.getcwd())


def main():
    model_path = {model_dir!r}
    output_dir = {output_dir!r}

    print(f"Loading model from {{model_path}}...")
    try:
        # NOTE: trust_remote_code=True runs code bundled with the model repo.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
    except Exception as e:
        print(f"Failed to load model: {{e}}")
        sys.exit(1)

    print("Generating inputs...")
{input_gen_code}

    # Move to CUDA if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tuple(t.to(device) for t in inputs)

    print("Starting extraction...")
    # Setup environment variable for GraphNet workspace
    os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = output_dir

    # Use the extract API from graph_net
    # extract(name, dynamic=True)(model) returns a compiled model
    # We need to run it once to trigger compilation and extraction
    from graph_net.torch.extractor import extract

    compiled_model = extract(name="subgraph", dynamic=True)(model)

    print("Running forward pass to trigger extraction...")
    with torch.no_grad():
        compiled_model(*inputs)

    print(f"Extraction complete. Results in {{output_dir}}")


if __name__ == "__main__":
    main()
"""

graph_net/agent/core.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import logging
3+
from typing import Optional
4+
5+
from .fetcher import HFFetcher
6+
from .analyzer import ConfigAnalyzer
7+
from .coder.template import TemplateCoder
8+
from .extractor import SubprocessExtractor
9+
from .verifier import BasicVerifier
10+
11+
class GraphNetAgent:
    """Drive the pipeline: download, analyze, generate script, extract, verify."""

    def __init__(self, workspace: str, hf_token: Optional[str] = None):
        """
        Set up the workspace directory and the pipeline components.

        Args:
            workspace (str): Directory where models and samples will be stored;
                it is made absolute and created if missing.
            hf_token (str, optional): Hugging Face API token, passed to the fetcher.
        """
        self.workspace = os.path.abspath(workspace)
        os.makedirs(self.workspace, exist_ok=True)

        agent_logger = logging.getLogger("GraphNetAgent")
        agent_logger.setLevel(logging.INFO)
        self.logger = agent_logger

        # Pipeline stages, built in the order they are used.
        self.fetcher = HFFetcher(self.workspace, token=hf_token)
        self.analyzer = ConfigAnalyzer()
        self.coder = TemplateCoder()
        self.extractor = SubprocessExtractor(self.workspace)
        self.verifier = BasicVerifier()

    def process_model(self, model_id: str) -> bool:
        """
        Run every pipeline stage for one model and report the outcome.

        Any exception raised by a stage is logged together with its traceback
        and turned into a ``False`` result instead of propagating.

        Args:
            model_id (str): Hugging Face model ID (e.g. 'bert-base-uncased')

        Returns:
            bool: True if the verifier accepts the extraction output, False otherwise.
        """
        log = self.logger
        log.info(f"Starting process for model: {model_id}")

        try:
            # Stage 1: fetch the model files.
            log.info("Step 1: Downloading model...")
            local_model_dir = self.fetcher.download(model_id)
            log.info(f"Model downloaded to: {local_model_dir}")

            # Stage 2: derive input metadata from the model config.
            log.info("Step 2: Analyzing model config...")
            spec = self.analyzer.analyze(local_model_dir)
            log.info(f"Analysis result: {spec}")

            # Stage 3: write the script that will run the model.
            log.info("Step 3: Generating run_model.py...")
            runner_script = self.coder.generate(local_model_dir, spec)
            log.info(f"Script generated at: {runner_script}")

            # Stage 4: run the script to extract the subgraph.
            log.info("Step 4: Extracting subgraph...")
            sample_dir = self.extractor.extract(runner_script, model_id)
            log.info(f"Extraction output dir: {sample_dir}")

            # Stage 5: check the extracted sample.
            log.info("Step 5: Verifying result...")
            if not self.verifier.verify(sample_dir):
                log.error(f"FAILURE: Verification failed for {model_id}.")
                return False

            log.info(f"SUCCESS: Model {model_id} processed successfully.")
            return True

        except Exception as e:
            log.error(f"Error processing {model_id}: {str(e)}")
            import traceback
            log.error(traceback.format_exc())
            return False

0 commit comments

Comments
 (0)