diff --git a/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt b/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt deleted file mode 100644 index 3ea9a1a9f..000000000 --- a/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt +++ /dev/null @@ -1,10 +0,0 @@ -#samples/timm/crossvit_small_240.in1k -#samples/timm/poolformerv2_s12.sail_in1k -#samples/timm/regnety_080.pycls_in1k -#samples/timm/dla46x_c.in1k -#samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k -samples/timm/efficientnetv2_rw_s.ra2_in1k -samples/timm/vit_base_patch16_rope_ape_224.naver_in1k -#samples/timm/fastvit_t8.apple_dist_in1k -#samples/timm/test_byobnet.r160_in1k -#samples/timm/mambaout_base.in1k diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py index 804b41637..da754d6d8 100644 --- a/graph_net/dimension_generalizer.py +++ b/graph_net/dimension_generalizer.py @@ -242,5 +242,7 @@ def update_tensor_metas_by_dyn_dim_cstr( if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1) - doubled_data = [*tensor_meta.data, *tensor_meta.data] - tensor_meta.data = doubled_data[:size] + extended_tensor_data = list(tensor_meta.data) + while len(extended_tensor_data) < size: + extended_tensor_data.extend(extended_tensor_data) + tensor_meta.data = extended_tensor_data[:size] diff --git a/graph_net/tools/apply_dim_gen_passes.sh b/graph_net/tools/apply_dim_gen_passes.sh index efeabbf9d..691b9858e 100755 --- a/graph_net/tools/apply_dim_gen_passes.sh +++ b/graph_net/tools/apply_dim_gen_passes.sh @@ -3,21 +3,19 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( os.path.dirname(graph_net.__file__))") -python3 -m graph_net.model_path_handler \ +python3 -m graph_net.apply_sample_pass \ --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt \ - --handler-config=$(base64 -w 0 < bool: def _node_need_rewrite(self, node) -> bool: if not (node.op == "call_method"): return False - if not (node.op == "expand"): + if not (node.target == "expand"): return False input_tensor_node = node.args[0] input_meta = input_tensor_node.meta.get("tensor_meta") diff --git a/graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py b/graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py index 541bee1d5..b362e6fda 100644 --- a/graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py +++ b/graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py @@ -1,4 +1,3 @@ -import torch import torch.fx as fx from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass from collections import namedtuple diff --git a/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py b/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py index c16295209..53fbd8369 100644 --- a/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py +++ b/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py @@ -19,7 +19,7 @@ def match(self) -> bool: return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier() def reify(self): - assert self.need_reify() + assert self.match() sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() reifier = self._get_map_nlp_sym_shapes_str2reifier()[sym_shapes_str] return reifier(self)