This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Keras-MXNet crashes when the hidden layer size of an LSTM is 1 #272

@maybeLee

When I use Keras with the MXNet backend to build a model with a single LSTM layer, Keras-MXNet crashes with the following traceback:

MXNET:
Traceback (most recent call last):
  File "exp1/job1/scripts/generation/script_prediction.py", line 123, in <module>
    _get_prediction(bk=bk, x=x[:1500], model_path=flags.model_path, batch_size=batch_size)
  File "exp1/job1/scripts/generation/script_prediction.py", line 35, in _get_prediction
    model = keras.models.load_model(model_path, custom_objects=custom_objects())
  File "lib/python3.6/site-packages/keras/engine/saving.py", line 496, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "lib/python3.6/site-packages/keras/engine/saving.py", line 302, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "lib/python3.6/site-packages/keras/engine/saving.py", line 535, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "lib/python3.6/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "lib/python3.6/site-packages/keras/engine/sequential.py", line 301, in from_config
    model.add(layer)
  File "lib/python3.6/site-packages/keras/engine/sequential.py", line 165, in add
    layer(x)
  File "lib/python3.6/site-packages/keras/layers/recurrent.py", line 532, in __call__
    return super(RNN, self).__call__(inputs, **kwargs)
  File "lib/python3.6/site-packages/keras/engine/base_layer.py", line 444, in __call__
    self.build(unpack_singleton(input_shapes))
  File "lib/python3.6/site-packages/keras/layers/recurrent.py", line 493, in build
    self.cell.build(step_input_shape)
  File "lib/python3.6/site-packages/keras/layers/recurrent.py", line 1901, in build
    constraint=self.bias_constraint)
  File "lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "lib/python3.6/site-packages/keras/engine/base_layer.py", line 255, in add_weight
    weight = K.variable(initializer(shape),
  File "lib/python3.6/site-packages/keras/layers/recurrent.py", line 1893, in bias_initializer
    self.bias_initializer((self.units * 2,), *args, **kwargs),
  File "lib/python3.6/site-packages/keras/backend/mxnet_backend.py", line 94, in func_wrapper
    train_symbol = func(*args, **kwargs)
  File "lib/python3.6/site-packages/keras/backend/mxnet_backend.py", line 2060, in concatenate
    symbols = [t.symbol for t in tensors]
  File "lib/python3.6/site-packages/keras/backend/mxnet_backend.py", line 2060, in <listcomp>
    symbols = [t.symbol for t in tensors]
AttributeError: 'NDArray' object has no attribute 'symbol'

The crash occurs only when the hidden size of the LSTM layer is set to 1. With any other hidden size, the bug does not appear.
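To make the failure in the last traceback frames easier to follow, here is a minimal sketch of the access pattern involved. It uses toy stand-ins only: `FakeKerasSymbol`, `FakeNDArray`, `forget_bias_parts`, and `raw_index` are hypothetical names, not keras-mxnet or MXNet APIs. As far as I understand the Keras LSTM cell, the default `unit_forget_bias=True` builds the bias by concatenating three segments (zeros, ones, zeros), which matches the `bias_initializer` frame above. Which segment ends up as a raw `NDArray` when the hidden size is 1, and why, is not established here; the sketch simply assumes one of them does.

```python
# Toy stand-ins only; these are NOT the real keras-mxnet or MXNet classes.
class FakeKerasSymbol:
    """Hypothetical wrapper that exposes a .symbol attribute."""
    def __init__(self, values):
        self.symbol = list(values)


class FakeNDArray:
    """Hypothetical raw array with no .symbol attribute."""
    def __init__(self, values):
        self.values = list(values)


def concatenate(tensors):
    # Same attribute access as the backend line shown in the traceback.
    symbols = [t.symbol for t in tensors]
    return [v for s in symbols for v in s]


def forget_bias_parts(units, raw_index=None):
    """Build the three bias segments: zeros, ones, zeros (twice as long).

    raw_index (an assumption for illustration) marks one segment that is
    returned as a raw array instead of a wrapped symbol.
    """
    parts = [[0.0] * units, [1.0] * units, [0.0] * (units * 2)]
    return [FakeNDArray(p) if i == raw_index else FakeKerasSymbol(p)
            for i, p in enumerate(parts)]


# All segments wrapped: concatenation succeeds.
ok = concatenate(forget_bias_parts(units=1))

# One segment left as a raw array: the same AttributeError pattern appears.
try:
    concatenate(forget_bias_parts(units=1, raw_index=0))
    error_message = None
except AttributeError as exc:
    error_message = str(exc)
```

If this reading is right, the bug would sit wherever the backend returns an unwrapped array for one of the initializer outputs, rather than in `concatenate` itself.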

You can reproduce this bug by running the following script:

import argparse
import os
import sys
import warnings

# Select the Keras backend before importing keras (defaults to mxnet).
parse = argparse.ArgumentParser()
parse.add_argument("--backend", type=str, default="mxnet", help="the name of backend")
flags, _ = parse.parse_known_args(sys.argv[1:])
os.environ["KERAS_BACKEND"] = flags.backend

import keras
from keras import layers
import numpy as np

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# A single LSTM layer with hidden size 1, the configuration reported to crash.
model = keras.models.Sequential()
model.add(layers.LSTM(1))
model.build((None, 49, 1))

# Random input: 10 samples, 49 timesteps, 1 feature.
x = np.random.rand(10, 49, 1)
pred = model.predict(x)

The tested versions are MXNet 1.5.1 and keras-mxnet 2.2.4.2.
