Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions graph_net/config/small_sample_list_for_get_fusible_subgraph.txt

This file was deleted.

6 changes: 4 additions & 2 deletions graph_net/dimension_generalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,7 @@ def update_tensor_metas_by_dyn_dim_cstr(
if tensor_meta.data is not None:
assert isinstance(tensor_meta.data, (list, tuple))
size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
doubled_data = [*tensor_meta.data, *tensor_meta.data]
tensor_meta.data = doubled_data[:size]
extended_tensor_data = list(tensor_meta.data)
while len(extended_tensor_data) < size:
extended_tensor_data.extend(extended_tensor_data)
tensor_meta.data = extended_tensor_data[:size]
24 changes: 11 additions & 13 deletions graph_net/tools/apply_dim_gen_passes.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

python3 -m graph_net.model_path_handler \
python3 -m graph_net.apply_sample_pass \
--model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt \
--handler-config=$(base64 -w 0 <<EOF
--sample-pass-file-path "$GRAPH_NET_ROOT/dimension_generalizer.py" \
--sample-pass-class-name "ApplyDimGenPasses" \
--sample-pass-config $(base64 -w 0 <<EOF
{
"handler_path": "$GRAPH_NET_ROOT/dimension_generalizer.py",
"handler_class_name": "ApplyDimGenPasses",
"handler_config": {
"resume": false,
"output_dir": "/tmp/dimension_generalized_samples",
"model_path_prefix": "$GRAPH_NET_ROOT/../",
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
"dimension_generalizer_class_name": "StaticToDynamic",
"limits_handled_models": 10,
"last_model_log_file": "/tmp/a.py"
}
"resume": false,
"output_dir": "/tmp/dimension_generalized_samples",
"model_path_prefix": "$GRAPH_NET_ROOT/../",
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
"dimension_generalizer_class_name": "StaticToDynamic",
"limits_handled_models": 40,
"last_model_log_file": "/tmp/a.py"
}
EOF
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
def _node_need_rewrite(self, node) -> bool:
if not (node.op == "call_method"):
return False
if not (node.op == "expand"):
if not (node.target == "expand"):
return False
input_tensor_node = node.args[0]
input_meta = input_tensor_node.meta.get("tensor_meta")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def create_new_node(node):
self.node_target(), args=new_node_args, kwargs=node.kwargs
)

return new_node
safe_arange_node = new_graph.call_function(
torch.remainder, args=(new_node, 512)
)

return safe_arange_node

for node in traced_module.graph.nodes:
val_map[node] = create_new_node(node)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.fx as fx
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
from collections import namedtuple
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def match(self) -> bool:
return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier()

def reify(self):
assert self.need_reify()
assert self.match()
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
reifier = self._get_map_nlp_sym_shapes_str2reifier()[sym_shapes_str]
return reifier(self)
Expand Down
Loading