@@ -876,6 +876,35 @@ def conv_2d_nhwc_fhwc_q(
876876 ) * (TypeFn .cast_signed (U , K [D .f , D .kh , D .kw , D .c ]) - TypeFn .cast_signed (U , KZp ))
877877
878878
879+ @linalg_structured_op
880+ def conv_2d_nchw_fchw_q (
881+ I = TensorDef (T1 , S .N , S .C , S .OH * S .SH + S .KH * S .DH , S .OW * S .SW + S .KW * S .DW ),
882+ K = TensorDef (T2 , S .F , S .C , S .KH , S .KW ),
883+ IZp = ScalarDef (I32 ),
884+ KZp = ScalarDef (I32 ),
885+ O = TensorDef (U , S .N , S .F , S .OH , S .OW , output = True ),
886+ strides = IndexAttrDef (S .SH , S .SW , default = [1 , 1 ]),
887+ dilations = IndexAttrDef (S .DH , S .DW , default = [1 , 1 ]),
888+ ):
889+ """Performs 2-D convolution with zero point offsets.
890+
891+ Layout:
892+ * Input: NCHW.
893+ * Kernel: FCHW.
894+
895+ Numeric casting is performed on the operands to the inner multiply, promoting
896+ them to the same data type as the accumulator/output. This includes the zero
897+ point offsets common to quantized operations.
898+ """
899+ implements (ConvolutionOpInterface )
900+ domain (D .n , D .f , D .oh , D .ow , D .c , D .kh , D .kw )
901+ O [D .n , D .f , D .oh , D .ow ] += (
902+ TypeFn .cast_signed (
903+ U , I [D .n , D .c , D .oh * S .SH + D .kh * S .DH , D .ow * S .SW + D .kw * S .DW ]
904+ )
905+ - TypeFn .cast_signed (U , IZp )
906+ ) * (TypeFn .cast_signed (U , K [D .f , D .c , D .kh , D .kw ]) - TypeFn .cast_signed (U , KZp ))
907+
879908@linalg_structured_op
880909def conv_2d_nchw_fchw (
881910 I = TensorDef (T1 , S .N , S .C , S .OH * S .SH + S .KH * S .DH , S .OW * S .SW + S .KW * S .DW ),
0 commit comments