@@ -319,6 +319,27 @@ def test_cat__same_shapes_converter_padding_last_dimension():
319319 assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
320320
321321
322+ def test_cat__same_shapes__channels_first__padding_channels ():
323+ target = "imxrt700"
324+
325+ # The Converter is capable of padding the last dimension of `cat` with the same input shapes.
326+ input_shape = (1 , 2 , 3 , 4 )
327+
328+ quantized_program = to_quantized_edge_program (
329+ CatConvModule (1 ),
330+ [input_shape , input_shape ],
331+ target = target ,
332+ neutron_converter_flavor = "SDK_25_09" ,
333+ custom_delegation_options = CustomDelegationOptions (),
334+ ).exported_program ()
335+
336+ # Make sure the `Cat` was delegated.
337+ assert not graph_contains_any_of_ops (
338+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
339+ )
340+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
341+
342+
322343def test_cat__same_shapes_converter_padding_middle_dimension ():
323344 target = "imxrt700"
324345
@@ -339,3 +360,78 @@ def test_cat__same_shapes_converter_padding_middle_dimension():
339360 assert not any (
340361 "lowered_module" in node .name for node in quantized_program .graph .nodes
341362 )
363+
364+
365+ def test_cat__format_specific_support__formatless (mocker ):
366+ # The last dim will end up being the channels, as the format is `formatless`.
367+ # Only the last dim satisfies the Neutron requirements for the channels.
368+ input_shape = (3 , 3 , 3 , 8 )
369+ num_inputs = 2
370+ dim = 2
371+
372+ input_shapes = [input_shape ] * num_inputs
373+
374+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
375+
376+ quantized_program = to_quantized_edge_program (
377+ CatModule (dim ), input_shapes
378+ ).exported_program ()
379+
380+ # Make sure the `Cat` was delegated.
381+ assert not graph_contains_any_of_ops (
382+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
383+ )
384+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
385+
386+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
387+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
388+ input_data = {
389+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
390+ for i , shape in enumerate (input_shapes )
391+ }
392+ convert_run_compare (
393+ exported_program ,
394+ tfl_model = tflite_flatbuffers_model ,
395+ input_data = input_data ,
396+ atol = 1 ,
397+ )
398+
399+
400+ def test_cat__format_specific_support__channels_first (mocker ):
401+ # The second dim will end up being the channels, as the format is `formatless`.
402+ # Only the second dim satisfies the Neutron requirements for the channels.
403+ input_shape = (3 , 8 , 3 , 3 )
404+ num_inputs = 2
405+ dim = 2
406+
407+ input_shapes = [input_shape ] * num_inputs
408+
409+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
410+
411+ channels = (
412+ sum (shape [1 ] for shape in input_shapes ) if dim in [1 , - 3 ] else input_shape [1 ]
413+ )
414+ quantized_program = to_quantized_edge_program (
415+ CatConvModule (dim , channels ), input_shapes
416+ ).exported_program ()
417+
418+ # Make sure the `Cat` was delegated.
419+ assert not graph_contains_any_of_ops (
420+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
421+ )
422+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
423+
424+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
425+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
426+ input_data = {
427+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
428+ for i , shape in enumerate (input_shapes )
429+ }
430+ convert_run_compare (
431+ exported_program ,
432+ tfl_model = tflite_flatbuffers_model ,
433+ input_data = input_data ,
434+ tflite_input_preprocess = ToNHWCPreprocess (),
435+ tflite_output_preprocess = ToNCHWPreprocess (),
436+ atol = 1 ,
437+ )
0 commit comments