22from graph_net .imp_util import load_module
33from graph_net .tensor_meta import TensorMeta
44from typing import Callable
5+ import functools
56import copy
67import sys
78import os
@@ -36,6 +37,7 @@ def _make_config(
3637 model_runnable_predicator_class_name = "ModelRunner" ,
3738 model_runnable_predicator_config = None ,
3839 model_path_prefix = "" ,
40+ resume = False ,
3941 ):
4042 if data_input_predicator_config is None :
4143 data_input_predicator_config = {}
@@ -49,10 +51,23 @@ def _make_config(
4951 "model_runnable_predicator_class_name" : model_runnable_predicator_class_name ,
5052 "model_runnable_predicator_config" : model_runnable_predicator_config ,
5153 "model_path_prefix" : model_path_prefix ,
54+ "resume" : resume ,
5255 }
5356
5457 def __call__ (self , model_path ):
5558 model_path = os .path .join (self .config ["model_path_prefix" ], model_path )
59+ print (f"{ model_path = } " )
60+ cstr_path = os .path .join (model_path , "input_tensor_constraints.py" )
61+ if (
62+ self .config ["resume" ]
63+ and os .path .exists (cstr_path )
64+ and DynamicDimConstraints .kSymbols in open (cstr_path ).read ()
65+ ):
66+ module = load_module (cstr_path )
67+ symbols = getattr (module , DynamicDimConstraints .kSymbols )
68+ if len (symbols ) > 0 :
69+ return
70+
5671 tensor_metas = self ._get_tensor_metas (model_path )
5772 dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas (tensor_metas )
5873
@@ -111,6 +126,11 @@ def update_tensor_metas_by_dyn_dim_cstr(
111126 assert len (tensor_metas ) == len (input_shapes )
112127 for i , tensor_meta in enumerate (tensor_metas ):
113128 tensor_meta .shape = input_shapes [i ]
129+ if tensor_meta .data is not None :
130+ assert isinstance (tensor_meta .data , (list , tuple ))
131+ size = functools .reduce (lambda a , b : a * b , tensor_meta .shape , 1 )
132+ doubled_data = [* tensor_meta .data , * tensor_meta .data ]
133+ tensor_meta .data = doubled_data [:size ]
114134
115135
116136def make_dyn_dim_cstr_from_tensor_metas (tensor_metas : list [TensorMeta ]):
@@ -152,7 +172,11 @@ def dumpy_filter_fn(input_name, input_idx, axis, dim):
152172 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
153173
154174 def filter_fn (input_name , input_idx , axis , dim ):
155- return is_data_input (input_name ) and dim == picked_dim
175+ return (
176+ is_data_input (input_name )
177+ and dim == picked_dim
178+ and (dim > 1 or axis == 0 )
179+ )
156180
157181 symbol = cur_dyn_dim_cstr .symbolize (filter_fn )
158182 if symbol is None :
0 commit comments