[BUG] Exception raised when saving a re-loaded model with a transformer block #878

@oliverholworthy

Bug description

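Saving a re-loaded model that contains a transformer block raises an exception. The error depends on the TensorFlow version: on TF 2.9.2 the model config cannot be serialized to JSON because it contains a Schema object, and on TF 2.10.0 the inputs passed during saving do not match any concrete function loaded from the SavedModel.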

TF 2.9.2 - TypeError: Unable to serialize - **Schema**
    def test_clm_reload():
        model_dir = "/tmp/clm_model"
        reloaded_model = mm.Model.load(model_dir)
>       reloaded_model.save("/tmp/clm_model_2")

tests/unit/tf/transformers/test_block.py:320:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67: in error_handler
    raise e.with_traceback(filtered_tb) from None
/usr/lib/python3.8/json/encoder.py:199: in encode
    chunks = self.iterencode(o, _one_shot=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = <keras.saving.saved_model.json_utils.Encoder object at 0x7fdcc8ba68e0>
o = {'backend': 'tensorflow', 'batch_input_shape': None, 'class_name': 'merlin.models>Model', 'config': {'0': {'class_name...AgYBBgEGAQYBBvs=\n', {...}, None)}, 'dtype': 'float32', 'function_type': 'lambda', ...}, 'shared_object_id': 21}}, ...}, _one_shot = True

    def iterencode(self, o, _one_shot=False):
        """Encode the given object and yield each string
        representation as available.

        For example::

            for chunk in JSONEncoder().iterencode(bigobject):
                mysocket.write(chunk)

        """
        if self.check_circular:
            markers = {}
        else:
            markers = None
        if self.ensure_ascii:
            _encoder = encode_basestring_ascii
        else:
            _encoder = encode_basestring

        def floatstr(o, allow_nan=self.allow_nan,
                _repr=float.__repr__, _inf=INFINITY, _neginf=-INFINITY):
            # Check for specials.  Note that this type of test is processor
            # and/or platform-specific, so do tests which don't depend on the
            # internals.

            if o != o:
                text = 'NaN'
            elif o == _inf:
                text = 'Infinity'
            elif o == _neginf:
                text = '-Infinity'
            else:
                return _repr(o)

            if not allow_nan:
                raise ValueError(
                    "Out of range float values are not JSON compliant: " +
                    repr(o))

            return text


        if (_one_shot and c_make_encoder is not None
                and self.indent is None):
            _iterencode = c_make_encoder(
                markers, self.default, _encoder, self.indent,
                self.key_separator, self.item_separator, self.sort_keys,
                self.skipkeys, self.allow_nan)
        else:
            _iterencode = _make_iterencode(
                markers, self.default, _encoder, self.indent, floatstr,
                self.key_separator, self.item_separator, self.sort_keys,
                self.skipkeys, _one_shot)
>       return _iterencode(o, 0)
E       TypeError: Unable to serialize [{'name': 'item_id_seq', 'tags': {<Tags.ITEM_ID: 'item_id'>, <Tags.ID: 'id'>, <Tags.SEQUENCE: 'sequence'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>}, 'properties': {'domain': {'min': 1, 'max': 51996, 'name': 'item_id_seq'}, 'value_count': {'min': 1, 'max': 4}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}, {'name': 'categories', 'tags': {<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.SEQUENCE: 'sequence'>, <Tags.LIST: 'list'>}, 'properties': {'domain': {'min': 1, 'max': 331, 'name': 'categories'}, 'value_count': {'min': 1, 'max': 4}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}] to JSON. Unrecognized type <class 'merlin.schema.schema.Schema'>.

/usr/lib/python3.8/json/encoder.py:257: TypeError
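
The TypeError above is the standard behavior of a JSON encoder when it meets a type it has no conversion for. The following is a minimal, standard-library-only sketch of that mechanism. FakeSchema, its to_dict method, and SchemaAwareEncoder are hypothetical stand-ins for illustration. They are not the real merlin.schema.Schema, not Keras's json_utils.Encoder, and not a proposed fix.

```python
import json


class FakeSchema:
    """Hypothetical stand-in for merlin.schema.Schema (not the real class)."""

    def __init__(self, columns):
        self.columns = columns

    def to_dict(self):
        # Hypothetical helper: return plain, JSON-friendly data.
        return {"columns": list(self.columns)}


class SchemaAwareEncoder(json.JSONEncoder):
    """Encoder whose default() hook knows how to convert FakeSchema."""

    def default(self, o):
        if isinstance(o, FakeSchema):
            return o.to_dict()
        # Defer to the base class, which raises TypeError for unknown types.
        return super().default(o)


config = {"class_name": "Model", "schema": FakeSchema(["item_id_seq", "categories"])}

# The plain encoder has no rule for FakeSchema, so it raises TypeError.
try:
    json.dumps(config)
    plain_error = None
except TypeError as exc:
    plain_error = str(exc)

# With the custom default() hook the same config serializes.
encoded = json.dumps(config, cls=SchemaAwareEncoder, sort_keys=True)
print(plain_error)
print(encoded)
```

In the traceback, the object that reaches the encoder is the model config, and the unrecognized value inside it is a Schema. That suggests the re-loaded model's config holds a raw Schema where the original model's config held something serializable, though the report does not confirm where it comes from.
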
TF 2.10.0 - Could not find matching concrete function to call loaded from the SavedModel
args = ({'categories': tf.RaggedTensor(values=Tensor("args_0:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_1:0",...e=(None,), dtype=float32), row_splits=Tensor("args_0_7:0", shape=(None,), dtype=int32)), ...}, None, False, None, None), kwargs = {}
do_return = False, retval_ = <tensorflow.python.autograph.operators.variables.UndefinedReturnValue object at 0x7f37e74460a0>

    def tf___wrapped_model(*args, **kwargs):
        "A concrete tf.function that wraps the model's call function."
        with ag__.FunctionScope('_wrapped_model', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
            do_return = False
            retval_ = ag__.UndefinedReturnValue()
            (args, kwargs) = ag__.converted_call(ag__.ld(model)._call_spec.set_arg_value, ('training', False, ag__.ld(args), ag__.ld(kwargs)), dict(inputs_in_args=True), fscope)
            with ag__.ld(base_layer_utils).call_context().enter(ag__.ld(model), inputs=None, build_graph=False, training=False, saving=True):
>               outputs = ag__.converted_call(ag__.ld(model), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
E               ValueError: in user code:
E
E                   File "/home/oliverholworthy/anaconda3/envs/python-3.8-rapids-22.10/lib/python3.8/site-packages/keras/saving/saving_utils.py", line 147, in _wrapped_model  *
E                       outputs = model(*args, **kwargs)
E                   File "/home/oliverholworthy/anaconda3/envs/python-3.8-rapids-22.10/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
E                       raise e.with_traceback(filtered_tb) from None
E
E                   ValueError: Exception encountered when calling layer "model" "                 f"(type merlin.models>Model).
E
E                   Could not find matching concrete function to call loaded from the SavedModel. Got:
E                     Positional arguments (5 total):
E                       * {'categories': tf.RaggedTensor(values=Tensor("inputs:0", shape=(None,), dtype=int64), row_splits=Tensor("inputs_1:0", shape=(None,), dtype=int32)),
E                    'event_hour_cos': tf.RaggedTensor(values=Tensor("inputs_2:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_3:0", shape=(None,), dtype=int32)),
E                    'event_hour_sin': tf.RaggedTensor(values=Tensor("inputs_4:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_5:0", shape=(None,), dtype=int32)),
E                    'event_weekday_cos': tf.RaggedTensor(values=Tensor("inputs_6:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_7:0", shape=(None,), dtype=int32)),
E                    'event_weekday_sin': tf.RaggedTensor(values=Tensor("inputs_8:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_9:0", shape=(None,), dtype=int32)),
E                    'item_age_days_norm': tf.RaggedTensor(values=Tensor("inputs_10:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_11:0", shape=(None,), dtype=int32)),
E                    'item_id_seq': tf.RaggedTensor(values=Tensor("inputs_12:0", shape=(None,), dtype=int64), row_splits=Tensor("inputs_13:0", shape=(None,), dtype=int32)),
E                    'test_user_id': <tf.Tensor 'inputs_14:0' shape=(None, 1) dtype=int64>,
E                    'user_age': <tf.Tensor 'inputs_15:0' shape=(None, 1) dtype=float32>,
E                    'user_country': <tf.Tensor 'inputs_16:0' shape=(None, 1) dtype=int64>}
E                       * None
E                       * False
E                       * None
E                       * None
E                     Keyword arguments: {}
E
E                    Expected these arguments to match one of the following 2 option(s):
E
E                   Option 1:
E                     Positional arguments (5 total):
E                       * {'categories': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'event_hour_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_hour_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_age_days_norm': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_id_seq': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'test_user_id': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/test_user_id'),
E                    'user_age': TensorSpec(shape=(None, 1), dtype=tf.float32, name='inputs/user_age'),
E                    'user_country': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/user_country')}
E                       * None
E                       * False
E                       * False
E                       * False
E                     Keyword arguments: {}
E
E                   Option 2:
E                     Positional arguments (5 total):
E                       * {'categories': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'event_hour_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_hour_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_age_days_norm': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_id_seq': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'test_user_id': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/test_user_id'),
E                    'user_age': TensorSpec(shape=(None, 1), dtype=tf.float32, name='inputs/user_age'),
E                    'user_country': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/user_country')}
E                       * None
E                       * True
E                       * False
E                       * False
E                     Keyword arguments: {}
E
E                   Call arguments received by layer "model" "                 f"(type merlin.models>Model):
E                     • args=({'item_age_days_norm': 'tf.RaggedTensor(values=Tensor("args_0_10:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_11:0", shape=(None,), dtype=int32))', 'item_id_seq': 'tf.RaggedTensor(values=Tensor("args_0_12:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_13:0", shape=(None,), dtype=int32))', 'event_hour_sin': 'tf.RaggedTensor(values=Tensor("args_0_4:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_5:0", shape=(None,), dtype=int32))', 'test_user_id': 'tf.Tensor(shape=(None, 1), dtype=int64)', 'categories': 'tf.RaggedTensor(values=Tensor("args_0:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_1:0", shape=(None,), dtype=int32))', 'event_weekday_sin': 'tf.RaggedTensor(values=Tensor("args_0_8:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_9:0", shape=(None,), dtype=int32))', 'event_weekday_cos': 'tf.RaggedTensor(values=Tensor("args_0_6:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_7:0", shape=(None,), dtype=int32))', 'user_country': 'tf.Tensor(shape=(None, 1), dtype=int64)', 'user_age': 'tf.Tensor(shape=(None, 1), dtype=float32)', 'event_hour_cos': 'tf.RaggedTensor(values=Tensor("args_0_2:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_3:0", shape=(None,), dtype=int32))'}, 'None', 'False', 'None', 'None')
E                     • kwargs=<class 'inspect._empty'>

/tmp/__autograph_generated_file024ppwfs.py:14: ValueError
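
To make the mismatch in the message above easier to see, here is a small sketch that compares only the trailing non-tensor positional arguments copied from the log. It does not check whether the tensors in the inputs dict match their specs, and the helper name mismatched_positions is made up for this illustration.

```python
# Trailing positional arguments copied from the log above.
# Position 1 (the inputs dict) is omitted from the comparison.
received = (None, False, None, None)
options = {
    "Option 1": (None, False, False, False),
    "Option 2": (None, True, False, False),
}


def mismatched_positions(got, expected):
    """Return 1-based positions (inputs dict counted as position 1) that differ."""
    return [i + 2 for i, (g, e) in enumerate(zip(got, expected)) if g != e]


report = {name: mismatched_positions(received, exp) for name, exp in options.items()}
for name, positions in report.items():
    print(name, "differs at positions", positions)
```

Under this comparison, the call made during saving passes None in the last two positions where both traced options expect False. That may be why no concrete function matches, but the report does not confirm it.
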

Steps/Code to reproduce bug

  1. Run the example tests below in order: test_clm, then test_clm_reload

Example test

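# Imports are not shown in the original report; the paths below are assumed
# from the test module (tests/unit/tf/transformers/test_block.py).
import merlin.models.tf as mm
from merlin.io import Dataset
from merlin.models.tf.loader import Loader
from merlin.schema import Tags


def test_clm(sequence_testing_data: Dataset):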
    seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag(
        Tags.CATEGORICAL
    )
    target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
    predict_next = mm.SequencePredictNext(schema=seq_schema, target=target)

    loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next)

    d_model = 48
    model = mm.Model(
        mm.InputBlockV2(
            seq_schema,
            categorical=mm.Embeddings(
                seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
            ),
        ),
        mm.MLPBlock([d_model]),
        mm.GPT2Block(d_model=d_model, n_head=8, n_layer=2),
        mm.CategoricalOutput(
            seq_schema.select_by_name(target), default_loss="categorical_crossentropy"
        ),
    )

    model.compile()
    model.fit(loader)

    model_dir = "/tmp/clm_model"
    model.save(model_dir)

def test_clm_reload():
    model_dir = "/tmp/clm_model"
    reloaded_model = mm.Model.load(model_dir)
    reloaded_model.save("/tmp/clm_model_2")

Produces the error:

  • Running test_clm_reload after test_clm.

Doesn't produce the error:

  • Moving the contents of test_clm_reload into test_clm, so the model is saved, loaded, and saved again within one test (possibly related to state that differs between the two cases)
  • Replacing mm.GPT2Block with mm.ListToDense()

Expected behavior

Saving a re-loaded model should succeed without errors and produce the same saved model artifact as the one that was loaded.
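
One way to check the "same artifact" expectation is to compare the two saved directories file by file. The sketch below uses only the standard library. dir_digest and make_dir are hypothetical helpers, and the file names are placeholders rather than the contents of a real saved model. It also assumes saved artifacts are byte-for-byte reproducible, which may not hold for real SavedModel directories. If it does not, comparing predictions of the two models may be a better check.

```python
import hashlib
import tempfile
from pathlib import Path


def dir_digest(root):
    """Hash every file under root (relative path plus bytes) into one hex digest."""
    root = Path(root)
    h = hashlib.sha256()
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        h.update(path.relative_to(root).as_posix().encode("utf-8"))
        h.update(b"\0")
        h.update(path.read_bytes())
    return h.hexdigest()


def make_dir(base, name, files):
    """Create base/name and write the given {relative_path: bytes} files into it."""
    d = Path(base) / name
    for rel, data in files.items():
        target = d / rel
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(data)
    return d


with tempfile.TemporaryDirectory() as tmp:
    # Placeholder contents standing in for two saved model directories.
    files = {"saved_model.pb": b"graph", "variables/variables.index": b"idx"}
    first = make_dir(tmp, "clm_model", files)
    second = make_dir(tmp, "clm_model_2", files)
    changed = make_dir(tmp, "clm_model_3", {**files, "saved_model.pb": b"other"})

    same = dir_digest(first) == dir_digest(second)
    different = dir_digest(first) != dir_digest(changed)

print(same, different)
```

In a regression test, dir_digest could be applied to /tmp/clm_model and /tmp/clm_model_2 once the second save succeeds.
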

Environment details

  • Merlin version: 22.10
  • Python version: 3.8
  • TensorFlow version (GPU): [2.9.1+nv22.8, 2.9.2, 2.10.0]

Labels

  • P0
  • bug