Skip to content

TimeSeriesTransformer summary network not deserializable #423

@vpratz

Description

@vpratz

Loading a saved `TimeSeriesTransformer` summary network fails with the error message below. The failure appeared when the network was added to the summary network test suite (`test_save_and_load`).
I do not know how to fix this yet. I will open a draft PR with the added test, and we can discuss how to resolve it there.

error_msgs = {139725321209360: (<Dense name=dense_47, built=True>, ValueError("Layer 'dense_47' expected 2 variables, but received 0 variables during loading. Expected: ['kernel', 'bias']"))}
warn_only = False                                                                                                      
                                                                                                                       
    def _raise_loading_failure(error_msgs, warn_only=False):                                                           
        first_key = list(error_msgs.keys())[0]                                                                         
        ex_saveable, ex_error = error_msgs[first_key]                                                                  
        msg = (                                                                                                        
            f"A total of {len(error_msgs)} objects could not "                                                         
            "be loaded. Example error message for "                                                                    
            f"object {ex_saveable}:\n\n"                                                                               
            f"{ex_error}\n\n"                                                                                          
            "List of objects that could not be loaded:\n"                                                              
            f"{[x[0] for x in error_msgs.values()]}"                                                                   
        )                                                                                                              
        if warn_only:                                                                                                  
            warnings.warn(msg)                                                                                         
        else:                                                                                                          
>           raise ValueError(msg)                                                                                      
E           ValueError: A total of 1 objects could not be loaded. Example error message for object <Dense name=dense_47, built=True>:
E                                                                                                                      
E           Layer 'dense_47' expected 2 variables, but received 0 variables during loading. Expected: ['kernel', 'bias']
E                                                                                                                      
E           List of objects that could not be loaded:                                                                  
E           [<Dense name=dense_47, built=True>]                                                                        
                                                                                                                       
/data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:631: ValueError          
_ test_save_and_load[summary_dim=4-batch_size=3-set_size=3-feature_size=2-summary_network='time_series_transformer'] __
                                                                                                                       
tmp_path = PosixPath('/tmp/pytest-of-valentin/pytest-4/test_save_and_load_summary_dim27')                              
summary_network = <TimeSeriesTransformer name=time_series_transformer_45, built=True>                                  
random_set = <tf.Tensor: shape=(3, 3, 2), dtype=float32, numpy=                                                        
array([[[-1.0457249 , -0.03069962],                                                                                    
        [-0.37220615, -1.355771... [[ 0.47269952, -0.5362306 ],                                                        
        [ 0.5662243 ,  0.54387766],                                                                                    
        [-1.3460127 , -0.19182666]]], dtype=float32)>                                                                  
                                                                                                                       
    def test_save_and_load(tmp_path, summary_network, random_set):                                                     
        if summary_network is None:                                                                                    
            pytest.skip(reason="Nothing to do, because there is no summary network.")                                  
                                                                                                                       
        summary_network.build(keras.ops.shape(random_set))                                                             
                                                                                                                       
        keras.saving.save_model(summary_network, tmp_path / "model.keras")                                             
>       loaded = keras.saving.load_model(tmp_path / "model.keras")                                                     
                                                                                                                       
tests/test_networks/test_summary_networks.py:77:                                                                       
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/saving/saving_api.py:189: in load_model       
    return saving_lib.load_model(                                                                                      
/data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:367: in load_model       
    return _load_model_from_fileobj(                                                                                   
/data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:509: in _load_model_from_fileobj
    _raise_loading_failure(error_msgs)                                                                                 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions