Skip to content

Commit dc0eb9a

Browse files
authored
statisticize in tensor symbolic shapes (#404)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add robustness code for generating input tensor constraints * Introduce input_tensor_constraints.py using shape propagation logic. * support dimension generalization for torch.Tensor.view and torch.Tensor.reshape * 1) support dimension generalization for torch.Tensor.expand(); 2) fix bugs in generalization for torch.Tensor.view and torch.Tensor.reshape * dimension_generalization_passes * Refactored DimensionGeneralizationPass.__init__ to accept argument dim_axes_pairs, enabling targeted configuration for specific use cases * save dimension generalization pass names into graph_net.json * Generalize sequence dimension * more dimension generalization passes for token dimension * refactor parse_sole_graph_module * Enhance the performance of the input_tensor_constraints.py file generation process. * minor fix * more dimension generalization pass * fix some hint bugs * generate input_tensor_constraints.py * statisticize in tensor symbolic shapes
1 parent f2c56eb commit dc0eb9a

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pathlib import Path
2+
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
3+
import sympy
4+
5+
6+
class GetInTensorSymbolicShapes:
7+
def __init__(self, config):
8+
self.config = self.make_config(**config)
9+
10+
def make_config(self, model_path_prefix):
11+
return {
12+
"model_path_prefix": model_path_prefix,
13+
}
14+
15+
def __call__(self, model_path):
16+
original_model_path = Path(self.config["model_path_prefix"]) / model_path
17+
input_tensor_cstr_filepath = original_model_path / "input_tensor_constraints.py"
18+
if not input_tensor_cstr_filepath.exists():
19+
print(f"get-in-tensor-symbolic-shapes None {model_path}")
20+
return
21+
dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
22+
str(input_tensor_cstr_filepath)
23+
)
24+
dyn_dim_cstrs.symbol2example_value = {}
25+
dyn_dim_cstrs.input_shapes = sorted(
26+
[
27+
tuple(shape)
28+
for shape, name in dyn_dim_cstrs.input_shapes
29+
if any(isinstance(dim, sympy.Expr) for dim in shape)
30+
],
31+
key=str,
32+
)
33+
input_shapes_str = str(dyn_dim_cstrs.input_shapes).replace(" ", "")
34+
print(f"get-in-tensor-symbolic-shapes {input_shapes_str} {model_path}")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
# model_runnable_predicator=ShapePropagatablePredicator
8+
model_runnable_predicator=ModelRunnablePredicator
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/tools/_get_in_tensor_symbolic_shapes.py",
12+
"handler_class_name": "GetInTensorSymbolicShapes",
13+
"handler_config": {
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../"
15+
}
16+
}
17+
EOF
18+
)
19+
CONFIG=$(echo $config_json_str | base64 -w 0)
20+
21+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#/bin/bash
2+
bash get_in_tensor_symbolic_shapes.sh | grep get-in-tensor-symbolic-shapes | awk '{print $2}' | sort | uniq -c | sort -nk1

0 commit comments

Comments
 (0)