@@ -52,15 +52,18 @@ def save_tflite_model(model, inp, name):
5252 converter = tf .lite .TFLiteConverter .from_concrete_functions ([func ])
5353 tflite_model = converter .convert ()
5454
55- interpreter = tf .lite .Interpreter (model_content = tflite_model )
56-
5755 with open (f'{ name } .tflite' , 'wb' ) as f :
5856 f .write (tflite_model )
5957
6058 out = model (inp )
59+ out = np .array (out )
60+
61+ if len (inp .shape ) == 4 :
62+ inp = inp .transpose (0 , 3 , 1 , 2 )
63+ out = out .transpose (0 , 3 , 1 , 2 )
6164
62- np .save (f'{ name } _inp.npy' , inp . transpose ( 0 , 3 , 1 , 2 ) )
63- np .save (f'{ name } _out_Identity.npy' , np . array ( out ). transpose ( 0 , 3 , 1 , 2 ) )
65+ np .save (f'{ name } _inp.npy' , inp )
66+ np .save (f'{ name } _out_Identity.npy' , out )
6467
6568
6669@tf .function (input_signature = [tf .TensorSpec (shape = [1 , 3 , 3 , 1 ], dtype = tf .float32 )])
@@ -75,3 +78,27 @@ def replicate_by_pack(x):
7578inp = np .random .standard_normal ((1 , 3 , 3 , 1 )).astype (np .float32 )
7679save_tflite_model (replicate_by_pack , inp , 'replicate_by_pack' )
7780
81+ @tf .function (input_signature = [tf .TensorSpec (shape = [1 , 3 ], dtype = tf .float32 )])
82+ def split (x ):
83+ splitted = tf .split (
84+ x , 3 , axis = - 1 , num = None , name = 'split'
85+ )
86+ return tf .concat ((splitted [2 ], splitted [1 ], splitted [0 ]), axis = - 1 )
87+
88+ inp = np .random .standard_normal ((1 , 3 )).astype (np .float32 )
89+ save_tflite_model (split , inp , 'split' )
90+
91+
92+ fully_connected = tf .keras .models .Sequential ([
93+ tf .keras .layers .Dense (3 ),
94+ tf .keras .layers .ReLU (),
95+ tf .keras .layers .Softmax (),
96+ ])
97+
98+ fully_connected = tf .function (
99+ fully_connected .call ,
100+ input_signature = [tf .TensorSpec ((1 ,2 ), tf .float32 )],
101+ )
102+
103+ inp = np .random .standard_normal ((1 , 2 )).astype (np .float32 )
104+ save_tflite_model (fully_connected , inp , 'fully_connected' )
0 commit comments