@@ -337,15 +337,16 @@ def _check_per_channel_group_params(
337337        # For now group quantization is only supported for 4b weights 
338338        assert  quant_params .is_qc4w , "Only 4b group quantization is supported" 
339339
340-     def  define_tensor (
340+     def  define_tensor (   # noqa: C901 
341341        self ,
342342        tensor : torch .fx .Node ,
343343        xnn_graph : XNNGraph ,
344344        vals_to_ids : Dict [torch .fx .Node , int ],
345345        convert_to_nhwc : bool  =  False ,
346-         swap_nc_for_depthwise_weights : bool  =  False ,
346+         swap_in_out_for_weights : bool  =  False ,
347347        quant_params : Optional [QuantParams ] =  None ,
348348        fp32_static_weights : bool  =  False ,
349+         groups : int  =  1 ,
349350    ) ->  None :
350351        """ 
351352        Defines an tensor value into the XNNGraph 
@@ -357,16 +358,21 @@ def define_tensor(
357358                        their corresponding ids in XNNGraph 
358359            convert_to_nhwc: bool to indicate whether tensor shape should be permuted to 
359360                        reflect the nhwc memory format. 
360-             swap_nc_for_depthwise_weights : bool to indicate whether tensor shape 
361-                         should be permuted such that the N and C dimensions are  
362-                         swapped , which should be used for depthwise convolution 
361+             swap_in_out_for_weights: bool to indicate whether the tensor shape should be
362+                         permuted and reshaped from (inc, oc/groups, height, width) to
363+                         (oc, inc/groups, height, width), which should be used for depthwise/transposed convolution
363364                        weights. This is only valid for tensors which hold 
364365                        constant data. If used along with convert_to_nhwc, this 
365366                        swap will happen before converting to nhwc. 
366367            quant_params: Quantization meta data for this tensor, None if it is not quantized 
367368            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv 
369+             groups: number of groups used to reshape the weights when swap_in_out_for_weights is set
368370        """ 
369371
372+         assert  (
373+             swap_in_out_for_weights  or  groups  ==  1 
374+         ), "groups is only valid when swap_in_out_for_weights is set"
375+ 
370376        if  tensor  in  vals_to_ids :
371377            return 
372378
@@ -394,15 +400,16 @@ def define_tensor(
394400            xnn_graph ,
395401            vals_to_ids ,
396402            convert_to_nhwc ,
397-             swap_nc_for_depthwise_weights ,
403+             swap_in_out_for_weights ,
398404            quant_params ,
399405            fp32_static_weights ,
406+             groups ,
400407        )
401408
402409        # convert tensor shape must reflect memory format, default is contiguous, so 
403410        # only permute shape if we are converting the tensor to nhwc format 
404-         if  swap_nc_for_depthwise_weights :
405-             dims  =  [dims [1 ], dims [0 ]] +  dims [2 :]
411+         if  swap_in_out_for_weights :
412+             dims  =  [dims [1 ]  *   groups , dims [0 ]  //   groups ] +  dims [2 :]
406413        if  convert_to_nhwc :
407414            check_or_raise (len (dims ) ==  4 , "Converting to nhwc requires 4d tensor" )
408415            dims  =  [dims [i ] for  i  in  PERM_NCHW_TO_NHWC ]
@@ -422,16 +429,16 @@ def define_tensor(
422429        )
423430
424431        # Override the quant params axis since we have 
425-         # updated the weights for depthwise, with that the out_channels dim 
432+         # updated the weights for depthwise / transposed conv2d; as a result the out_channels dim
426433        # will be dims[3] instead of dims[0]. Let's update the per_channel 
427434        # quant axis to match the new weight tensor before serializing 
428-         if  swap_nc_for_depthwise_weights  and  (
429-             quant_params  and  quant_params .per_channel 
430-         ):
435+         if  swap_in_out_for_weights  and  (quant_params  and  quant_params .per_channel ):
431436            if  quant_params .axis  ==  0 :
432437                quant_params .axis  =  len (dims ) -  1 
438+             elif  quant_params .axis  ==  1 :
439+                 quant_params .axis  =  0 
433440            else :
434-                 assert  f"Unsupported weight per channel quantization axis for depthwise conv2d: { quant_params .axis }  , expecting 0." 
441+                 raise AssertionError(f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1.")
435442
436443        # Serialize tensor value 
437444        ser_val  =  (
@@ -492,9 +499,10 @@ def get_serialized_buffer_index(
492499        xnn_graph : XNNGraph ,
493500        vals_to_ids : Dict [torch .fx .Node , int ],
494501        convert_to_nhwc : bool ,
495-         swap_nc_for_depthwise_weights : bool ,
502+         swap_in_out_for_weights : bool ,
496503        quant_params : Optional [QuantParams ],
497504        fp32_static_weights : bool  =  False ,
505+         groups : int  =  1 ,
498506    ) ->  int :
499507        """ 
500508        If tensor holds some constant data, serialize it and return the 
@@ -507,24 +515,30 @@ def get_serialized_buffer_index(
507515                        their corresponding ids in XNNGraph 
508516            convert_to_nhwc: bool to indicate whether tensor shape should be permuted to 
509517                        reflect the nhwc memory format. 
510-             swap_nc_for_depthwise_weights : bool to indicate whether tensor shape 
511-                         should be permuted such that the N and C dimensions are  
512-                         swapped , which should be used for depthwise convolution 
518+             swap_in_out_for_weights: bool to indicate whether the tensor shape should be
519+                         permuted and reshaped from (inc, oc/groups, height, width) to
520+                         (oc, inc/groups, height, width), which should be used for depthwise/transposed convolution
513521                        weights. This is only valid for tensors which hold 
514522                        constant data. If used along with convert_to_nhwc, this 
515523                        swap will happen before converting to nhwc. 
516524            quant_params: Quantization meta data for this tensor, None if it is not quantize 
517525            fp32_static_weights: bool to indicate whether tensor is fp32 static weights 
526+             groups: number of groups used to reshape the weights when swap_in_out_for_weights is set
518527
519528        Returns: 
520529            buffer_idx: idx of the serialized data. 0 If not associated constant 
521530                        data 
522531        """ 
532+ 
533+         assert  (
534+             swap_in_out_for_weights  or  groups  ==  1 
535+         ), "groups is only valid when swap_in_out_for_weights is set"
536+ 
523537        # The get_attr node is the input to quant_params. 
524538        get_attr_node  =  tensor  if  quant_params  is  None  else  quant_params .q_input 
525539        if  not  is_param_node (self .exported_program , get_attr_node ):
526540            check_or_raise (
527-                 not  swap_nc_for_depthwise_weights ,
541+                 not  swap_in_out_for_weights ,
528542                "Swapping N and C dimensions is only valid for constant data tensors" ,
529543            )
530544            return  0 
@@ -541,9 +555,16 @@ def get_serialized_buffer_index(
541555            # ensure that the const is fp32 
542556            const_val  =  const_val .to (dtype = torch .float32 ).contiguous ()
543557
544-         if  swap_nc_for_depthwise_weights :
545-             const_val  =  const_val .permute (
546-                 dims = ((1 , 0 ) +  tuple (range (2 , const_val .dim ())))
558+         if  swap_in_out_for_weights :
559+             # Permute and reshape the tensor from (inc, oc/groups, height, width) to (oc, inc/groups, height, width) 
560+             # which should be used for depthwise/transpose convolution weights for XNNPACK 
561+             shape  =  const_val .shape 
562+             const_val  =  const_val .reshape (
563+                 (groups , const_val .shape [0 ] //  groups ) +  const_val .shape [1 :]
564+             )
565+             const_val  =  const_val .permute ((0 , 2 , 1 ) +  tuple (range (3 , const_val .dim ())))
566+             const_val  =  const_val .reshape (
567+                 (shape [1 ] *  groups , shape [0 ] //  groups ) +  shape [2 :]
547568            ).contiguous ()
548569
549570        if  convert_to_nhwc :
0 commit comments