@@ -449,7 +449,7 @@ def test_quantized_layer_norm_per_tensor(
449449 ), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
450450 memory_format ,
451451 )
452- for memory_format in [torch .contiguous_format ]
452+ for memory_format in [torch .contiguous_format , torch . channels_last ]
453453 ],
454454 # Test case 5: Multiple output channels
455455 * [
@@ -686,10 +686,13 @@ def test_quantized_conv_per_tensor(
686686 ) -> None :
687687 assert memory_format in [torch .contiguous_format , torch .channels_last ]
688688
689- if len (input_tensor .shape ) == 3 and memory_format == torch .channels_last :
690- self .fail ("Channels last format is not supported for 3D input tensors" )
691-
692- input_tensor = input_tensor .to (memory_format = memory_format )
689+ if memory_format == torch .channels_last :
690+ if input_tensor .ndim == 3 :
691+ input_tensor = input_tensor .movedim (1 , - 1 )
692+ weight = weight .movedim (1 , - 1 )
693+ else :
694+ input_tensor = input_tensor .movedim (- 3 , - 1 )
695+ weight = weight .movedim (- 3 , - 1 )
693696
694697 convs = [
695698 (
@@ -701,7 +704,7 @@ def test_quantized_conv_per_tensor(
701704
702705 optimized_convs = []
703706 if input_tensor .dtype == torch .int8 and weight .dtype == torch .int8 :
704- if input_tensor . is_contiguous ( memory_format = torch .contiguous_format ) :
707+ if memory_format == torch .contiguous_format :
705708 optimized_convs = [
706709 torch .ops .cadence .quantized_conv_nchw_asym8sxsym8s_asym8s .per_tensor ,
707710 torch .ops .cadence .quantized_conv_nchw_dilated_asym8sxsym8s_asym8s .per_tensor ,
@@ -715,7 +718,7 @@ def test_quantized_conv_per_tensor(
715718 torch .ops .cadence .quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor ,
716719 ]
717720 elif input_tensor .dtype == torch .uint8 and weight .dtype == torch .uint8 :
718- if input_tensor . is_contiguous ( memory_format = torch .contiguous_format ) :
721+ if memory_format == torch .contiguous_format :
719722 optimized_convs = [
720723 torch .ops .cadence .quantized_conv_nchw_asym8uxsym8u_asym8u .per_tensor ,
721724 torch .ops .cadence .quantized_conv_nchw_dilated_asym8uxsym8u_asym8u .per_tensor ,
@@ -746,7 +749,13 @@ def test_quantized_conv_per_tensor(
746749 output_zero_point ,
747750 out_multiplier ,
748751 out_shift ,
749- ).to (memory_format = torch .contiguous_format )
752+ )
753+
754+ if memory_format == torch .channels_last :
755+ if input_tensor .ndim == 3 :
756+ output = output .movedim (- 1 , 1 )
757+ else :
758+ output = output .movedim (- 1 , - 3 )
750759
751760 # Verify output properties
752761 self .assertEqual (output .dtype , dtype , f"Output dtype should be { dtype } " )
0 commit comments