22set -x
33
44OP_NUM=${1:- 64}
5- GPU_ID=${2:- 4 }
5+ GPU_ID=${2:- 0 }
66
77export CUDA_VISIBLE_DEVICES=" ${GPU_ID} "
8- export PYTHONPATH=/work/GraphNet:/work/abstract_pass/Athena:$PYTHONPATH
98
109GRAPH_NET_ROOT=$( python3 -c " import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))" )
1110
12- DECOMPOSE_WORKSPACE=/work/graphnet_test_workspace/subgraph_dataset_20251221
11+ DECOMPOSE_WORKSPACE=/tmp/subgraph_dataset_workspace
1312LEVEL_DECOMPOSE_WORKSPACE=$DECOMPOSE_WORKSPACE /decomposed_${OP_NUM} ops
1413OP_NAMES_OUTPUT_DIR=${DECOMPOSE_WORKSPACE} /sample_op_names
1514RANGE_DECOMPOSE_OUTPUT_DIR=" ${LEVEL_DECOMPOSE_WORKSPACE} /range_decompose"
1615GRAPH_VAR_RENAME_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE /graph_var_renamed
1716DEDUPLICATED_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE /deduplicated
17+ DEVICE_REWRITED_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE /device_rewrited
1818UNITTESTS_OUTPUT_DIR=$LEVEL_DECOMPOSE_WORKSPACE /unittests
1919
2020mkdir -p " $LEVEL_DECOMPOSE_WORKSPACE "
2121
2222model_list=" $GRAPH_NET_ROOT /graph_net/config/torch_samples_list.txt"
2323range_decomposed_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE} /range_decomposed_subgraph_sample_list.txt
24+ device_rewrited_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE} /device_rewrited_subgraph_sample_list.txt
2425deduplicated_subgraph_list=${LEVEL_DECOMPOSE_WORKSPACE} /deduplicated_subgraph_sample_list.txt
2526
2627function generate_subgraph_list() {
@@ -64,6 +65,7 @@ function generate_split_point() {
6465 # level 5: 32, 64
6566 MIN_SEQ_OPS=$(( ${OP_NUM} / 2 ))
6667 MAX_SEQ_OPS=${OP_NUM}
68+
6769 echo " >>> [2] Generate split points for samples in ${model_list} ."
6870 echo " >>> OP_NUM: ${OP_NUM} , MIN_SEQ_OPS: ${MIN_SEQ_OPS} , MAX_SEQ_OPS: ${MAX_SEQ_OPS} "
6971 echo " >>>"
@@ -90,7 +92,7 @@ function range_decompose() {
9092 "handler_path": "$GRAPH_NET_ROOT /graph_net/torch/graph_decomposer.py",
9193 "handler_class_name": "RangeDecomposerExtractor",
9294 "handler_config": {
93- "resume": false ,
95+ "resume": true ,
9496 "model_path_prefix": "$GRAPH_NET_ROOT ",
9597 "output_dir": "${RANGE_DECOMPOSE_OUTPUT_DIR} ",
9698 "split_results_path": "$LEVEL_DECOMPOSE_WORKSPACE /split_results_${OP_NUM} .json",
@@ -135,21 +137,41 @@ function remove_duplicates() {
135137 --target-dir ${DEDUPLICATED_OUTPUT_DIR}
136138}
137139
138- function generate_unittests () {
139- echo " >>> [6] Generate unittests for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR} ."
140+ function rewrite_device () {
141+ echo " >>> [6] Rewrite devices for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR} ."
140142 echo " >>>"
141143 python3 -m graph_net.model_path_handler \
142144 --model-path-list ${deduplicated_subgraph_list} \
143145 --handler-config=$( base64 -w 0 << EOF
146+ {
147+ "handler_path": "$GRAPH_NET_ROOT /graph_net/torch/sample_passes/device_rewrite_sample_pass.py",
148+ "handler_class_name": "DeviceRewriteSamplePass",
149+ "handler_config": {
150+ "device": "cuda",
151+ "resume": true,
152+ "model_path_prefix": "${DEDUPLICATED_OUTPUT_DIR} ",
153+ "output_dir": "${DEVICE_REWRITED_OUTPUT_DIR} "
154+ }
155+ }
156+ EOF
157+ )
158+ }
159+
160+ function generate_unittests() {
161+ echo " >>> [7] Generate unittests for subgraph samples under ${DEDUPLICATED_OUTPUT_DIR} ."
162+ echo " >>>"
163+ python3 -m graph_net.model_path_handler \
164+ --model-path-list ${device_rewrited_subgraph_list} \
165+ --handler-config=$( base64 -w 0 << EOF
144166{
145167 "handler_path": "$GRAPH_NET_ROOT /graph_net/sample_pass/agent_unittest_generator.py",
146168 "handler_class_name": "AgentUnittestGeneratorPass",
147169 "handler_config": {
148170 "framework": "torch",
149- "model_path_prefix": "${DEDUPLICATED_OUTPUT_DIR } ",
171+ "model_path_prefix": "${DEVICE_REWRITED_OUTPUT_DIR } ",
150172 "output_dir": "$UNITTESTS_OUTPUT_DIR ",
151173 "device": "cuda",
152- "generate_main": true ,
174+ "generate_main": false ,
153175 "try_run": true,
154176 "resume": true,
155177 "data_input_predicator_filepath": "$GRAPH_NET_ROOT /graph_net/torch/constraint_util.py",
163185main () {
164186 timestamp=` date +%Y%m%d_%H%M`
165187 suffix=" ${OP_NUM} ops_${timestamp} "
166- # generate_op_names 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE}/log_op_names_${suffix}.txt
188+
189+ generate_op_names 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_op_names_${suffix} .txt
167190 generate_split_point 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_split_point_${suffix} .txt
168191 range_decompose 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_range_decompose_${suffix} .txt
169192
@@ -172,7 +195,10 @@ main() {
172195 remove_duplicates 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_remove_duplicates_${suffix} .txt
173196
174197 generate_subgraph_list ${DEDUPLICATED_OUTPUT_DIR} ${deduplicated_subgraph_list}
175- generate_unittests 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_generate_unittests_${suffix} .txt
198+ rewrite_device 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_rewrite_device_${suffix} .txt
199+
200+ generate_subgraph_list ${DEVICE_REWRITED_OUTPUT_DIR} ${device_rewrited_subgraph_list}
201+ generate_unittests 2>&1 | tee ${LEVEL_DECOMPOSE_WORKSPACE} /log_unittests_${suffix} .txt
176202}
177203
178204main
0 commit comments