Skip to content

Commit d2d9e0b

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
2 parents 6059328 + 7c43d38 commit d2d9e0b

File tree

788 files changed

+843860
-20
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

788 files changed

+843860
-20
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
### PR Category
2+
<!-- One of [ New Sample | Feature Enhancement | Bug Fix | Other ] -->
3+
4+
5+
### Description
6+
<!-- Describe what you’ve done -->

CONTRIBUTE_TUTORIAL.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ from graph_net.torch.extractor import extract
4848
4949
def run_model(name: str, device_str: str) -> None:
5050
"""
51-
Run computation graph extraction for the specified model.
51+
Run computational graph extraction for the specified model.
5252
5353
Args:
5454
name (str): Model name (e.g., 'resnet50', 'vit_b_16', 'bert-base-uncased').
@@ -222,14 +222,14 @@ Re-run `extract(name)(model)(input_data)` to complete extraction.
222222
* Insert hooks at specific layers by passing them into `wrapped = extract(...)(model, hooks=...)`.
223223
224224
225-
### 3. Extracting the Computation Graph
225+
### 3. Extracting the Computational Graph
226226
227227
1. **Run extract**
228228
229229
Execute scripts with `@graph_net.torch.extract` or `@graph_net.paddle.extract`. For example:
230230
231231
```bash
232-
# Extract the ResNet‑18 computation graph
232+
# Extract the ResNet‑18 computational graph
233233
python -m graph_net.test.vision_model_test
234234
```
235235
@@ -241,7 +241,7 @@ Expected output will be saved under `$GRAPH_NET_EXTRACT_WORKSPACE`.
241241
python -m graph_net.torch.validate --model-path $GRAPH_NET_EXTRACT_WORKSPACE/model_name
242242
```
243243
244-
`validate` checks if the extracted graph meets the Dataset Construction Constraints. If success, you’re ready to submit.
244+
`validate` checks if the extracted graph meets the Dataset Construction Constraints. If success, you’re ready to continue.
245245
246246
247247
### 4. Submitting the Extracted Graph
@@ -263,16 +263,17 @@ python -m graph_net.pack --output /path/to/output.zip --clear-after-pack True
263263
264264
This API:
265265
266-
a. Packages all files under `$GRAPH_NET_EXTRACT_WORKSPACE` into `/path/to/output.zip`
266+
a. Packages all files under `$GRAPH_NET_EXTRACT_WORKSPACE` into `/path/to/output.zip` (You can set it to `GraphNet/samples`)
267267
268268
b. Clears the workspace if `--clear-after-pack` is `True`
269269
270-
Note: If third-party ops are used, contributors must include them manually in the graph directory. As long as `validate` passes, no specific folder structure is required.
270+
Note: If third-party ops are used, contributors must include them manually in the package. As long as `validate` passes, no specific folder structure is required.
271271
272272
3. **Commit the changes**
273273
274+
Move the packaged computational graph in the previous step to **samples** directory and commit.
274275
```bash
275-
git add <new files>
276+
git add <the packaged computational graph>
276277
git commit -m "Description"
277278
```
278279

CONTRIBUTE_TUTORIAL_cn.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ python -m graph_net.test.vision_model_test
234234
```bash
235235
python -m graph_net.torch.validate --model-path $GRAPH_NET_EXTRACT_WORKSPACE/model_name
236236
```
237-
`validate` 验证您刚刚抽取的计算图符合Dataset Construction Constraints,如果结果为Success,则可以提交
237+
`validate` 验证您刚刚抽取的计算图符合Dataset Construction Constraints,如果结果为Success,则可以继续
238238

239239

240240

@@ -254,16 +254,17 @@ python -m graph_net.pack --output /path/to/output.zip --clear-after-pack True
254254
```
255255
该API的功能为:
256256

257-
a. 打包`$GRAPH_NET_EXTRACT_WORKSPACE`下的所有文件到`/path/to/output.zip`
257+
a. 打包`$GRAPH_NET_EXTRACT_WORKSPACE`下的所有文件到`/path/to/output.zip` (可以设置到`GraphNet/samples`
258258

259259
b. 若`--clear-after-pack``True`,则打包后清空`$GRAPH_NET_EXTRACT_WORKSPACE`
260260

261-
请注意,如果有第三方算子,需要贡献者自行打包到计算图目录内。目前没有特别规定存放的目录结构,但只要通过了validate环节,就可以达到验收标准。
261+
请注意,如果有第三方算子,需要贡献者自行打包到计算图压缩包内。目前没有特别规定存放的目录结构,但只要通过了validate环节,就可以达到验收标准。
262262

263263
3. **提交修改**
264264

265+
移动上一步打包完成的计算图压缩包到**samples**目录,然后提交。
265266
```bash
266-
git add <新增的文件>
267+
git add <计算图压缩包>
267268
git commit -m "描述"
268269
```
269270
4. **推送分支到远程**(你的 Fork 仓库)

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ With GraphNet, users can:
88
3. Train AI‑for‑Systems models to automatically generate compiler optimization passes.
99

1010

11-
Dataset Construction Constraints:
11+
### Dataset Construction Constraints:
1212
1. Dynamic graphs must execute correctly.
1313
2. Each computation graph should include a standardized method for measuring performance.
1414
3. Graphs and their corresponding Python code must support serialization and deserialization.
@@ -91,10 +91,10 @@ Once you have packaged these extracted computation graphs, submit them to the Gr
9191
<table>
9292
<tr>
9393
<td align="center">
94-
<img width="190" height="220" src="https://github.com/user-attachments/assets/31b4f0ba-417e-48b6-a860-124d74bd6643" />
94+
<img width="200" src="https://github.com/user-attachments/assets/30f034dd-f7d9-49f5-bae8-30ba2ac9c6b4" />
9595
</td>
9696
<td align="center">
97-
<img width="190" height="220" src="https://github.com/user-attachments/assets/140fa03e-36ef-44bf-8d9a-ca65c83b0139" />
97+
<img width="200" src="https://github.com/user-attachments/assets/140fa03e-36ef-44bf-8d9a-ca65c83b0139" />
9898
</td>
9999
</tr>
100100
</table>

README_cn.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ GraphNet —— 一个面向编译器开发的大规模数据集,旨在为研
88
2. 方便已有编译器做回归测试
99
3. 训练AI-for-system模型以自动生成编译器优化Pass
1010

11-
数据集构建约束:
11+
### 数据集构建约束:
1212

1313
1. 动态图能正常运行
1414
2. 每份计算图有通用方法测定性能指标
@@ -90,7 +90,7 @@ python -m graph_net.config --global\
9090
<table>
9191
<tr>
9292
<td align="center">
93-
<img width="190" height="220" src="https://github.com/user-attachments/assets/31b4f0ba-417e-48b6-a860-124d74bd6643" />
93+
<img width="190" height="220" src="https://github.com/user-attachments/assets/1a42cceb-f026-44a6-acbe-dee810410893" />
9494
</td>
9595
<td align="center">
9696
<img width="190" height="220" src="https://github.com/user-attachments/assets/140fa03e-36ef-44bf-8d9a-ca65c83b0139" />
@@ -100,4 +100,6 @@ python -m graph_net.config --global\
100100
</div>
101101

102102
## 开源协议
103-
[MIT License](LICENSE)
103+
[MIT License](LICENSE)
104+
105+

graph_net/test/bert_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def create_model():
2424
inputs = {k: v.to(device) for k, v in inputs.items()}
2525

2626
model = create_model()
27-
model = graph_net.torch.extract(name=get_model_name())(model)
27+
model = graph_net.torch.extract(name=get_model_name(), dynamic=True)(model)
2828

2929
print("Running inference...")
3030
output = model(**inputs)

graph_net/test/vision_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
model.to(device)
2929
normalized_input = normalized_input.to(device)
3030

31-
model = graph_net.torch.extract(name="resnet18")(model)
31+
model = graph_net.torch.extract(name="resnet18", dynamic=True)(model)
3232

3333
print("Running inference...")
3434
print("Input shape:", normalized_input.shape)

graph_net/torch/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,23 @@
88
import argparse
99
import importlib
1010
import inspect
11+
import math
1112

1213

1314
def apply_templates(forward_code: str) -> str:
1415
tab = " "
1516
forward_code = f"\n{tab}".join(forward_code.split("\n"))
16-
return f"import torch\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
17+
imports = "import torch"
18+
if "device" in forward_code:
19+
imports += "\n\nfrom torch import device"
20+
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
1721

1822

1923
def get_limited_precision_float_str(value):
2024
if not isinstance(value, float):
2125
return value
26+
if not math.isfinite(value):
27+
return f'float("{value}")'
2228
return f"{value:.3f}"
2329

2430

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ff10075e46eeae7615c783a3970e04ae79c1ea3f5b6020b167bec1e19489372f
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1
5+
}

0 commit comments

Comments
 (0)