Skip to content

Commit b203ea1

Browse files
authored
Add device_rewrite in subgraph dataset generating script. (#472)
1 parent e385a60 commit b203ea1

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

graph_net/tools/generate_subgraph_dataset.sh

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,26 @@
22
set -x
33

44
OP_NUM=${1:-64}
5-
GPU_ID=${2:-4}
5+
GPU_ID=${2:-0}
66

77
export CUDA_VISIBLE_DEVICES="${GPU_ID}"
8-
export PYTHONPATH=/work/GraphNet:/work/abstract_pass/Athena:$PYTHONPATH
98

109
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
1110

12-
DECOMPOSE_WORKSPACE=/work/graphnet_test_workspace/subgraph_dataset_20251221
11+
DECOMPOSE_WORKSPACE=/tmp/subgraph_dataset_workspace
1312
LEVEL_DECOMPOSE_WORKSPACE=$DECOMPOSE_WORKSPACE/decomposed_${OP_NUM}ops
1413
OP_NAMES_OUTPUT_DIR=${DECOMPOSE_WORKSPACE}/sample_op_names
1514
RANGE_DECOMPOSE_OUTPUT_DIR="${LEVEL_DECOMPOSE_WORKSPACE}/range_decompose"
1615
GRAPH_VAR_RENAME_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE/graph_var_renamed
1716
DEDUPLICATED_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE/deduplicated
17+
DEVICE_REWRITED_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE/device_rewrited
1818
UNITTESTS_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE/unittests
1919

2020
mkdir -p "$LEVEL_DECOMPOSE_WORKSPACE"
2121

2222
model_list="$GRAPH_NET_ROOT/graph_net/config/torch_samples_list.txt"
2323
range_decomposed_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE}/range_decomposed_subgraph_sample_list.txt
24+
device_rewrited_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE}/device_rewrited_subgraph_sample_list.txt
2425
deduplicated_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE}/deduplicated_subgraph_sample_list.txt
2526

2627
function generate_subgraph_list() {
@@ -64,6 +65,7 @@ function generate_split_point() {
6465
# level 5: 32, 64
6566
MIN_SEQ_OPS=$((${OP_NUM} / 2))
6667
MAX_SEQ_OPS=${OP_NUM}
68+
6769
echo ">>> [2] Generate split points for samples in ${model_list}."
6870
echo ">>> OP_NUM: ${OP_NUM}, MIN_SEQ_OPS: ${MIN_SEQ_OPS}, MAX_SEQ_OPS: ${MAX_SEQ_OPS}"
6971
echo ">>>"
@@ -90,7 +92,7 @@ function range_decompose() {
9092
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
9193
"handler_class_name": "RangeDecomposerExtractor",
9294
"handler_config": {
93-
"resume": false,
95+
"resume": true,
9496
"model_path_prefix": "$GRAPH_NET_ROOT",
9597
"output_dir": "${RANGE_DECOMPOSE_OUTPUT_DIR}",
9698
"split_results_path": "$LEVEL_DECOMPOSE_WORKSPACE/split_results_${OP_NUM}.json",
@@ -135,21 +137,41 @@ function remove_duplicates() {
135137
--target-dir ${DEDUPLICATED_OUTPUT_DIR}
136138
}
137139

138-
function generate_unittests() {
139-
echo ">>> [6] Generate unittests for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR}."
140+
#######################################
# Stage [6]: run the DeviceRewriteSamplePass handler over the deduplicated
# subgraph samples through graph_net.model_path_handler.
# The handler configuration is a JSON document passed base64-encoded on the
# command line. It requests device "cuda" with resume enabled, reading
# samples relative to DEDUPLICATED_OUTPUT_DIR and writing to
# DEVICE_REWRITED_OUTPUT_DIR.
# NOTE(review): what the pass changes inside each sample is defined in
# device_rewrite_sample_pass.py, not here -- confirm there.
# Globals (read):
#   GRAPH_NET_ROOT, DEDUPLICATED_OUTPUT_DIR, DEVICE_REWRITED_OUTPUT_DIR,
#   deduplicated_subgraph_list
# Arguments: none
# Outputs:   progress banner on stdout, plus whatever the handler prints
# Returns:   exit status of the python3 invocation, or the failing status
#            of base64 if the configuration could not be encoded
#######################################
function rewrite_device() {
  echo ">>> [6] Rewrite devices for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR}."
  echo ">>>"

  # Encode the JSON configuration first so an encoding failure is detected
  # before the handler is launched. The here-doc delimiter is unquoted on
  # purpose: the shell variables inside the JSON must be expanded.
  local handler_config
  handler_config=$(base64 -w 0 <<EOF
{
  "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/device_rewrite_sample_pass.py",
  "handler_class_name": "DeviceRewriteSamplePass",
  "handler_config": {
    "device": "cuda",
    "resume": true,
    "model_path_prefix": "${DEDUPLICATED_OUTPUT_DIR}",
    "output_dir": "${DEVICE_REWRITED_OUTPUT_DIR}"
  }
}
EOF
  ) || return

  # Quote both expansions so a workspace path containing whitespace or glob
  # characters is passed to the handler as a single argument.
  python3 -m graph_net.model_path_handler \
    --model-path-list "${deduplicated_subgraph_list}" \
    --handler-config="${handler_config}"
}
159+
160+
function generate_unittests() {
161+
echo ">>> [7] Generate unittests for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR}."
162+
echo ">>>"
163+
python3 -m graph_net.model_path_handler \
164+
--model-path-list ${device_rewrited_subgraph_list} \
165+
--handler-config=$(base64 -w 0 <<EOF
144166
{
145167
"handler_path": "$GRAPH_NET_ROOT/graph_net/sample_pass/agent_unittest_generator.py",
146168
"handler_class_name": "AgentUnittestGeneratorPass",
147169
"handler_config": {
148170
"framework": "torch",
149-
"model_path_prefix": "${DEDUPLICATED_OUTPUT_DIR}",
171+
"model_path_prefix": "${DEVICE_REWRITED_OUTPUT_DIR}",
150172
"output_dir": "$UNITTESTS_OUTPUT_DIR",
151173
"device": "cuda",
152-
"generate_main": true,
174+
"generate_main": false,
153175
"try_run": true,
154176
"resume": true,
155177
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
@@ -163,7 +185,8 @@ EOF
163185
main() {
164186
timestamp=`date +%Y%m%d_%H%M`
165187
suffix="${OP_NUM}ops_${timestamp}"
166-
#generate_op_names 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_op_names_${suffix}.txt
188+
189+
generate_op_names 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_op_names_${suffix}.txt
167190
generate_split_point 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_split_point_${suffix}.txt
168191
range_decompose 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_range_decompose_${suffix}.txt
169192

@@ -172,7 +195,10 @@ main() {
172195
remove_duplicates 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_remove_duplicates_${suffix}.txt
173196

174197
generate_subgraph_list ${DEDUPLICATED_OUTPUT_DIR} ${deduplicated_subgraph_list}
175-
generate_unittests 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_generate_unittests_${suffix}.txt
198+
rewrite_device 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_rewrite_device_${suffix}.txt
199+
200+
generate_subgraph_list ${DEVICE_REWRITED_OUTPUT_DIR} ${device_rewrited_subgraph_list}
201+
generate_unittests 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_unittests_${suffix}.txt
176202
}
177203

178204
main

0 commit comments

Comments
 (0)