@@ -337,7 +337,7 @@ def _check_per_channel_group_params(
337337 # For now group quantization is only supported for 4b weights
338338 assert quant_params .is_qc4w , "Only 4b group quantization is supported"
339339
340- def define_tensor (
340+ def define_tensor ( # noqa: C901
341341 self ,
342342 tensor : torch .fx .Node ,
343343 xnn_graph : XNNGraph ,
@@ -346,6 +346,8 @@ def define_tensor(
346346 swap_nc_for_depthwise_weights : bool = False ,
347347 quant_params : Optional [QuantParams ] = None ,
348348 fp32_static_weights : bool = False ,
349+ swap_in_out_for_transpose_weights : bool = False ,
350+ groups : int = 1 ,
349351 ) -> None :
350352 """
351353 Defines an tensor value into the XNNGraph
@@ -365,6 +367,9 @@ def define_tensor(
365367 swap will happen before converting to nhwc.
366368 quant_params: Quantization meta data for this tensor, None if it is not quantized
367369 fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
370+ swap_in_out_for_transpose_weights: bool to indicate whether the tensor shape should be
371+ permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
372+ groups: number of groups for swap_in_out_for_transpose_weights
368373 """
369374
370375 if tensor in vals_to_ids :
@@ -397,12 +402,16 @@ def define_tensor(
397402 swap_nc_for_depthwise_weights ,
398403 quant_params ,
399404 fp32_static_weights ,
405+ swap_in_out_for_transpose_weights ,
406+ groups ,
400407 )
401408
402409 # convert tensor shape must reflect memory format, default is contiguous, so
403410 # only permute shape if we are converting the tensor to nhwc format
404411 if swap_nc_for_depthwise_weights :
405412 dims = [dims [1 ], dims [0 ]] + dims [2 :]
413+ if swap_in_out_for_transpose_weights :
414+ dims = [dims [1 ] * groups , dims [0 ] // groups ] + dims [2 :]
406415 if convert_to_nhwc :
407416 check_or_raise (len (dims ) == 4 , "Converting to nhwc requires 4d tensor" )
408417 dims = [dims [i ] for i in PERM_NCHW_TO_NHWC ]
@@ -433,6 +442,14 @@ def define_tensor(
433442 else :
434443 assert f"Unsupported weight per channel quantization axis for depthwise conv2d: { quant_params .axis } , expecting 0."
435444
445+ if swap_in_out_for_transpose_weights and (
446+ quant_params and quant_params .per_channel
447+ ):
448+ if quant_params .axis == 0 :
449+ quant_params .axis = len (dims ) - 1
450+ else :
451+ assert f"Unsupported weight per channel quantization axis for conv_transpose2d: { quant_params .axis } , expecting 0."
452+
436453 # Serialize tensor value
437454 ser_val = (
438455 XValue (xvalue_union = tvalue )
@@ -495,6 +512,8 @@ def get_serialized_buffer_index(
495512 swap_nc_for_depthwise_weights : bool ,
496513 quant_params : Optional [QuantParams ],
497514 fp32_static_weights : bool = False ,
515+ swap_in_out_for_transpose_weights : bool = False ,
516+ groups : int = 1 ,
498517 ) -> int :
499518 """
500519 If tensor holds some constant data, serialize it and return the
@@ -546,6 +565,16 @@ def get_serialized_buffer_index(
546565 dims = ((1 , 0 ) + tuple (range (2 , const_val .dim ())))
547566 ).contiguous ()
548567
568+ if swap_in_out_for_transpose_weights :
569+ shape = const_val .shape
570+ const_val = const_val .reshape (
571+ (groups , const_val .shape [0 ] // groups ) + const_val .shape [1 :]
572+ )
573+ const_val = const_val .permute ((0 , 2 , 1 ) + tuple (range (3 , const_val .dim ())))
574+ const_val = const_val .reshape (
575+ (shape [1 ] * groups , shape [0 ] // groups ) + shape [2 :]
576+ ).contiguous ()
577+
549578 if convert_to_nhwc :
550579 const_val = const_val .to (memory_format = torch .channels_last )
551580
0 commit comments