@@ -68,13 +68,18 @@ def data_layers(self):
68
68
get all data layer
69
69
:return:
70
70
"""
71
- data_layers = set ()
71
+ data_layers = dict ()
72
72
73
73
def find_data_layer (layer ):
74
74
if isinstance (layer , v2_layer .DataLayerV2 ):
75
- data_layers .add (layer )
76
- for parent_layer in layer .__parent_layers__ .values ():
77
- find_data_layer (parent_layer )
75
+ data_layers [layer .name ] = layer
76
+ if not isinstance (layer , collections .Sequence ):
77
+ for parent_layer in layer .__parent_layers__ .values ():
78
+ find_data_layer (parent_layer )
79
+ else :
80
+ for each_l in layer :
81
+ for parent_layer in each_l .__parent_layers__ .values ():
82
+ find_data_layer (parent_layer )
78
83
79
84
for layer in self .layers :
80
85
find_data_layer (layer )
@@ -86,8 +91,12 @@ def data_type(self):
86
91
get data_type from proto, such as:
87
92
[('image', dense_vector(768)), ('label', integer_value(10))]
88
93
"""
89
- return [(data_layer .name , data_layer .type )
90
- for data_layer in self .data_layers ()]
94
+
95
+ data_types_lists = []
96
+ for each in self .__model_config__ .input_layer_names :
97
+ data_layers = self .data_layers ()
98
+ data_types_lists .append ((each , data_layers [each ].type ))
99
+ return data_types_lists
91
100
92
101
93
102
def __check_layer_type__ (layer ):
0 commit comments