@@ -296,3 +296,78 @@ def test_cat__force_delegate():
296296 graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
297297 )
298298 assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
299+
300+
301+ def test_cat__format_specific_support__formatless (mocker ):
302+ # The last dim will end up being the channels, as the format is `formatless`.
303+ # Only the last dim satisfies the Neutron requirements for the channels.
304+ input_shape = (3 , 3 , 3 , 8 )
305+ num_inputs = 2
306+ dim = 2
307+
308+ input_shapes = [input_shape ] * num_inputs
309+
310+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
311+
312+ quantized_program = to_quantized_edge_program (
313+ CatModule (dim ), input_shapes
314+ ).exported_program ()
315+
316+ # Make sure the `Cat` was delegated.
317+ assert not graph_contains_any_of_ops (
318+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
319+ )
320+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
321+
322+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
323+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
324+ input_data = {
325+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
326+ for i , shape in enumerate (input_shapes )
327+ }
328+ convert_run_compare (
329+ exported_program ,
330+ tfl_model = tflite_flatbuffers_model ,
331+ input_data = input_data ,
332+ atol = 1 ,
333+ )
334+
335+
336+ def test_cat__format_specific_support__channels_first (mocker ):
337+ # The second dim will end up being the channels, as the format is `formatless`.
338+ # Only the second dim satisfies the Neutron requirements for the channels.
339+ input_shape = (3 , 8 , 3 , 3 )
340+ num_inputs = 2
341+ dim = 2
342+
343+ input_shapes = [input_shape ] * num_inputs
344+
345+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
346+
347+ channels = (
348+ sum (shape [1 ] for shape in input_shapes ) if dim in [1 , - 3 ] else input_shape [1 ]
349+ )
350+ quantized_program = to_quantized_edge_program (
351+ CatConvModule (dim , channels ), input_shapes
352+ ).exported_program ()
353+
354+ # Make sure the `Cat` was delegated.
355+ assert not graph_contains_any_of_ops (
356+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
357+ )
358+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
359+
360+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
361+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
362+ input_data = {
363+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
364+ for i , shape in enumerate (input_shapes )
365+ }
366+ convert_run_compare (
367+ exported_program ,
368+ tfl_model = tflite_flatbuffers_model ,
369+ input_data = input_data ,
370+ tflite_input_preprocess = ToNHWCPreprocess (),
371+ tflite_output_preprocess = ToNCHWPreprocess (),
372+ atol = 1 ,
373+ )
0 commit comments