
Commit 9b48cfd

[cherry-pick][Dy2Stat]Support non-tensor type in input_spec (#33464) #34378
1 parent dbc54d2 commit 9b48cfd

6 files changed: +296 -49 lines changed

python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py

Lines changed: 3 additions & 10 deletions
@@ -193,14 +193,8 @@ def _verify_input_spec(self, input_spec):
             raise TypeError(
                 "The type(input_spec) should be one of (tuple, list), but received {}.".
                 format(type_name(input_spec)))
-        input_spec = tuple(input_spec)
-        for spec in flatten(input_spec):
-            if not isinstance(spec, paddle.static.InputSpec):
-                raise ValueError(
-                    "The type(elem) from input_spec should be `InputSpec`, but received {}.".
-                    format(type_name(spec)))

-        return input_spec
+        return tuple(input_spec)

     def __repr__(self):
         return "function: {}({}), input_spec: {}".format(
@@ -326,9 +320,8 @@ def check_type_and_len(input, spec, check_length=False):
     elif isinstance(input_spec, paddle.static.InputSpec):
         return input_spec
     else:
-        raise TypeError(
-            "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
-            type_name(input_spec))
+        # NOTE(Aurelius84): Support non-Tensor type as input spec info
+        return input_spec


 def replace_spec_empty_name(args_name, input_with_spec):
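With the per-element ValueError removed, `_verify_input_spec` now accepts non-tensor entries and simply returns the tuple. A minimal sketch of what this enables at the API level (the function `foo`, its shapes, and the bool argument are illustrative, not from the commit):

    import paddle
    from paddle.static import InputSpec

    # A bool rides along in input_spec; before this commit it raised
    # "The type(elem) from input_spec should be `InputSpec`".
    @paddle.jit.to_static(
        input_spec=[InputSpec(shape=[None, 10], dtype='float32'), True])
    def foo(x, use_act):
        out = paddle.matmul(x, paddle.ones([10, 4]))
        return paddle.nn.functional.relu(out) if use_act else out

    y = foo(paddle.rand([2, 10]), True)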

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

Lines changed: 6 additions & 3 deletions
@@ -20,7 +20,6 @@
 import six
 import textwrap
 import threading
-import warnings
 import weakref

 from paddle.fluid import framework
@@ -314,7 +313,7 @@ def __call__(self, *args, **kwargs):
             # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
             # will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to
             # display this warning message only once.
-            warnings.warn(
+            logging_utils.warn(
                 "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
                 "We will just return dygraph output. If you would like to get static graph output, please call API "
                 "ProgramTranslator.enable(True)")
@@ -481,6 +480,10 @@ def concrete_program_specify_input_spec(self, input_spec=None):
             # NOTE(chenweihang): we should always translated program based on the `input_spec`
             # decorated on forward if it is valid
             desired_input_spec = self._function_spec.input_spec
+            if input_spec is not None:
+                logging_utils.warn(
+                    "\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".
+                    format(desired_input_spec, input_spec))

             has_input_spec = (desired_input_spec is not None)
             if has_input_spec:
@@ -886,7 +889,7 @@ def func(x):
         if not self.enable_to_static:
             # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
             # will show up **only once**.
-            warnings.warn(
+            logging_utils.warn(
                 "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
                 "We will just return dygraph output. "
                 "Please call ProgramTranslator.enable(True) if you would like to get static output."

python/paddle/fluid/dygraph/dygraph_to_static/utils.py

Lines changed: 63 additions & 23 deletions
@@ -27,6 +27,7 @@
 import textwrap
 import numpy as np

+import paddle
 from paddle.fluid import unique_name
 from paddle.fluid.data_feeder import convert_dtype

@@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None):
     """
     Makes input `x` hashable.

-    For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values.
+    For some unhashable objects, such as `dict/list/set/np.ndarray`, apply the hash function to their values.
     """
-    if isinstance(x, (tuple, list)):
+    if isinstance(x, (tuple, list, set)):
         return tuple(map(make_hashable, x))

     try:
@@ -1428,10 +1429,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
     Returns True if the two input specs are compatible, otherwise False.

     args:
-        src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
-            paddle.static.InputSpec
-        desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
-            paddle.static.InputSpec
+        src_input_spec (list or tuple[InputSpec etc.]): list/tuple of
+            paddle.static.InputSpec or int/str etc.
+        desired_input_specs (list or tuple[InputSpec etc.]): list/tuple of
+            paddle.static.InputSpec or int/str etc.
     """
     len_specs = len(src_input_specs)
     if len_specs != len(desired_input_specs):
@@ -1440,30 +1441,69 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
         for spec in src_input_specs:
             if spec not in desired_input_specs:
                 return False
-
     else:
-        for i in range(len_specs):
-            src_shape = src_input_specs[i].shape
-            other_shape = desired_input_specs[i].shape
-            len_shape = len(src_shape)
-            if len_shape != len(other_shape):
-                return False
-            for j in range(len_shape):
-                if src_shape[j] is None or src_shape[j] < 0:
-                    continue
-                if other_shape[j] is None or other_shape[j] < 0:
-                    continue
-                if src_shape[j] != other_shape[j]:
+        for (src_spec, desired_spec) in zip(src_input_specs,
+                                            desired_input_specs):
+            if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
+                    desired_spec, paddle.static.InputSpec):
+                if not _compatible_tensor_spec(src_spec, desired_spec):
+                    return False
+            else:
+                if not _compatible_non_tensor_spec(src_spec, desired_spec):
                     return False

-            src_dtype = convert_dtype(src_input_specs[i].dtype)
-            other_dtype = convert_dtype(desired_input_specs[i].dtype)
-            if src_dtype != other_dtype:
-                return False
+    return True
+
+
+def _compatible_tensor_spec(src_spec, desired_spec):
+    """
+    Check whether two tensor-type specs are compatible.
+    """
+    for spec in [src_spec, desired_spec]:
+        if not isinstance(spec, paddle.static.InputSpec):
+            return False
+    src_shape = src_spec.shape
+    other_shape = desired_spec.shape
+    len_shape = len(src_shape)
+    if len_shape != len(other_shape):
+        return False
+    for j in range(len_shape):
+        if src_shape[j] is None or src_shape[j] < 0:
+            continue
+        if other_shape[j] is None or other_shape[j] < 0:
+            continue
+        if src_shape[j] != other_shape[j]:
+            return False
+
+    src_dtype = convert_dtype(src_spec.dtype)
+    other_dtype = convert_dtype(desired_spec.dtype)
+    if src_dtype != other_dtype:
+        return False

     return True


+def _compatible_non_tensor_spec(src_spec, desired_spec):
+    """
+    Check whether two non-tensor-type specs are compatible.
+    """
+
+    def hash_value(spec):
+        try:
+            hash_val = make_hashable(spec)
+        except:
+            hash_val = None
+        return hash_val
+
+    src_hash_val = hash_value(src_spec)
+    desired_hash_val = hash_value(desired_spec)
+
+    if src_hash_val != desired_hash_val:
+        return False
+    else:
+        return True
+
+
 def slice_is_num(slice_node):
     # A slice_node.slice can be a:
     # (1) ast.Index, which is a simple number such as [1], [-2]
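The non-tensor branch reduces to "equal after `make_hashable`". A simplified, self-contained sketch of that rule, re-implemented here for illustration under assumptions about the fallback cases (the real helpers live in the private dygraph_to_static module):

    import numpy as np

    def make_hashable(x):
        # Recursively turn unhashable containers (dict/list/set/np.ndarray)
        # into hashable tuples built from their values.
        if isinstance(x, (tuple, list, set)):
            return tuple(map(make_hashable, x))
        try:
            hash(x)
        except TypeError:
            if isinstance(x, dict):
                return tuple(map(make_hashable, x.items()))
            if isinstance(x, np.ndarray):
                return tuple(x.ravel().tolist())
            raise
        return x

    def compatible_non_tensor_spec(src_spec, desired_spec):
        # Mirrors _compatible_non_tensor_spec: specs match when their
        # hashable forms compare equal; unhashable inputs fall back to None.
        def hash_value(spec):
            try:
                return make_hashable(spec)
            except TypeError:
                return None
        return hash_value(src_spec) == hash_value(desired_spec)

    print(compatible_non_tensor_spec({'k': [1, 2]}, {'k': [1, 2]}))  # True
    print(compatible_non_tensor_spec(10, '10'))                      # False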

python/paddle/fluid/dygraph/jit.py

Lines changed: 15 additions & 8 deletions
@@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec):
     ]
     if input_spec is None:
         # no prune
-        result_list = input_var_names
-    elif input_spec is not None and len(input_spec) == len(input_var_names):
+        return input_var_names
+    else:
+        # filter out non-tensor type spec infos.
+        input_spec = [
+            spec for spec in input_spec
+            if isinstance(spec, paddle.static.InputSpec)
+        ]
+
+    if len(input_spec) == len(input_var_names):
         # no prune
         result_list = input_var_names
     # if input spec name not in input_var_names, only raise warning
@@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs):
     Args:
         layer (Layer|function): The Layer or function to be saved.
         path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
-        input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
-            method, which can be described by InputSpec or example Tensor. If None, all input variables of
+        input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
+            method, which can be described by InputSpec or example Tensor. Moreover, non-tensor type arguments,
+            such as int, float, string, or list/dict of them, are supported. If None, all input variables of
             the original Layer's forward method would be the inputs of the saved model. Default None.
         **configs (dict, optional): Other save configuration options for compatibility. We do not
             recommend using these configurations, they may be removed in the future. If not necessary,
@@ -698,9 +706,8 @@ def fun(inputs):
             inner_input_spec.append(
                 paddle.static.InputSpec.from_tensor(var))
         else:
-            raise TypeError(
-                "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
-                % type(var))
+            # NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
+            inner_input_spec.append(var)

     # parse configs
     configs = _parse_save_configs(configs)
@@ -719,7 +726,7 @@ def fun(inputs):
                 inner_input_spec)
         elif 'forward' == attr_func:
             # transform in jit.save, if input_spec is incomplete, declarative will throw error
-            # inner_input_spec is list[InputSpec], it should be packed with same sturcture
+            # inner_input_spec is list[InputSpec], it should be packed with same structure
             # as original input_spec here.
             if inner_input_spec:
                 inner_input_spec = pack_sequence_as(input_spec,
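A small standalone sketch of the new filtering step in `_get_input_var_names` (the names are illustrative): only `InputSpec` entries contribute input variable names when pruning, while non-tensor entries are dropped:

    from paddle.static import InputSpec

    input_spec = [InputSpec([None, 10], 'float32', name='x'), 0.5, 'mean']
    # filter out non-tensor type spec infos, as the patched code does
    tensor_specs = [
        spec for spec in input_spec if isinstance(spec, InputSpec)
    ]
    print([spec.name for spec in tensor_specs])  # ['x']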

python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py

Lines changed: 0 additions & 4 deletions
@@ -39,10 +39,6 @@ def test_verify_input_spec(self):
         with self.assertRaises(TypeError):
             foo_spec = FunctionSpec(foo_func, input_spec=a_spec)

-        # each element of input_spec should be `InputSpec`
-        with self.assertRaises(ValueError):
-            foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])
-
         foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
         self.assertTrue(len(foo_spec.flat_input_spec) == 2)
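The deleted case reflects the new contract: mixing an `InputSpec` with a plain Python value is now accepted. A hedged sketch of the updated expectation, self-contained rather than copied from the test module (`FunctionSpec` is the internal class the test exercises):

    import paddle
    from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec

    def foo_func(a, b):
        return a + b

    a_spec = paddle.static.InputSpec([None, 10], name='a')
    # Previously raised ValueError; after this commit the int 10 is kept
    # as a non-tensor spec entry.
    foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])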
