@@ -343,10 +343,9 @@ def define_tensor(  # noqa: C901
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool = False,
-        swap_nc_for_depthwise_weights: bool = False,
+        swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
         fp32_static_weights: bool = False,
-        swap_in_out_for_transpose_weights: bool = False,
         groups: int = 1,
     ) -> None:
         """
@@ -359,19 +358,21 @@ def define_tensor(  # noqa: C901
                         their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                         reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                        should be permuted such that the N and C dimensions are
-                        swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                        permuted and reshaped from (inc, oc/groups, height, width) to
+                        (oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
                         weights. This is only valid for tensors which hold
                         constant data. If used along with convert_to_nhwc, this
                         swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
             fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
-            swap_in_out_for_transpose_weights: bool to indicate whether tensor shape should be
-                        permuted and reshape from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
-            groups: number of groups for swap_in_out_for_transpose_weights
+            groups: number of groups for swap_in_out_for_weights
         """
 
+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid together with swap_in_out_for_weights"
+
         if tensor in vals_to_ids:
             return
 
@@ -399,18 +400,15 @@ def define_tensor(  # noqa: C901
             xnn_graph,
             vals_to_ids,
             convert_to_nhwc,
-            swap_nc_for_depthwise_weights,
+            swap_in_out_for_weights,
             quant_params,
             fp32_static_weights,
-            swap_in_out_for_transpose_weights,
             groups,
         )
 
         # convert tensor shape must reflect memory format, default is contiguous, so
         # only permute shape if we are converting the tensor to nhwc format
-        if swap_nc_for_depthwise_weights:
-            dims = [dims[1], dims[0]] + dims[2:]
-        if swap_in_out_for_transpose_weights:
+        if swap_in_out_for_weights:
             dims = [dims[1] * groups, dims[0] // groups] + dims[2:]
         if convert_to_nhwc:
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
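
Not part of the diff: a minimal sketch of the shape arithmetic in the new branch, using illustrative weight shapes. With groups == 1 the unified formula reduces to the old depthwise N/C swap; with groups > 1 it maps a grouped conv_transpose2d weight layout (inc, oc/groups, kh, kw) to (oc, inc/groups, kh, kw).

    def swap_in_out(dims, groups=1):
        # same arithmetic as the new swap_in_out_for_weights branch above
        return [dims[1] * groups, dims[0] // groups] + dims[2:]

    # groups == 1: identical to the removed depthwise swap [dims[1], dims[0]] + dims[2:]
    assert swap_in_out([32, 1, 3, 3]) == [1, 32, 3, 3]
    # groups == 2: e.g. inc=8, oc=16 -> weight shape (8, 8, 3, 3) becomes (16, 4, 3, 3)
    assert swap_in_out([8, 8, 3, 3], groups=2) == [16, 4, 3, 3]
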
@@ -431,24 +429,16 @@ def define_tensor(  # noqa: C901
         )
 
         # Override the quant params axis since we have
-        # updated the weights for depthwise, with that the out_channels dim
+        # updated the weights for depthwise / transposed_conv2d, with that the out_channels dim
         # will be dims[3] instead of dims[0]. Let's update the per_channel
         # quant axis to match the new weight tensor before serializing
-        if swap_nc_for_depthwise_weights and (
-            quant_params and quant_params.per_channel
-        ):
-            if quant_params.axis == 0:
-                quant_params.axis = len(dims) - 1
-            else:
-                assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
-
-        if swap_in_out_for_transpose_weights and (
-            quant_params and quant_params.per_channel
-        ):
+        if swap_in_out_for_weights and (quant_params and quant_params.per_channel):
             if quant_params.axis == 0:
                 quant_params.axis = len(dims) - 1
+            elif quant_params.axis == 1:
+                quant_params.axis = 0
             else:
-                assert f"Unsupported weight per channel quantization axis for conv_transpose2d: {quant_params.axis}, expecting 0."
+                assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1."
 
         # Serialize tensor value
         ser_val = (
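
Not part of the diff: a small torch sketch of the axis bookkeeping above, on illustrative shapes and assuming the per-channel axis tracks the out-channel dimension. For a depthwise weight quantized along axis 0, the out-channels land in the last dim after the swap plus NHWC permute; for a conv_transpose2d weight quantized along axis 1, they land in dim 0 after the swap.

    import torch

    oc, inc, kh, kw = 32, 8, 3, 3

    # depthwise conv2d weight (oc, 1, kh, kw), per-channel quant axis 0
    w_dw = torch.randn(oc, 1, kh, kw).permute(1, 0, 2, 3)    # in/out swap (groups == 1)
    w_dw = w_dw.permute(0, 2, 3, 1)                          # NCHW -> NHWC
    assert w_dw.shape[-1] == oc                              # axis 0 -> len(dims) - 1

    # conv_transpose2d weight (inc, oc, kh, kw), per-channel quant axis 1
    w_tc = torch.randn(inc, oc, kh, kw).permute(1, 0, 2, 3)  # in/out swap (groups == 1)
    assert w_tc.shape[0] == oc                               # axis 1 -> 0
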
@@ -509,10 +499,9 @@ def get_serialized_buffer_index(
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool,
-        swap_nc_for_depthwise_weights: bool,
+        swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
         fp32_static_weights: bool = False,
-        swap_in_out_for_transpose_weights: bool = False,
         groups: int = 1,
     ) -> int:
@@ -526,24 +515,30 @@ def get_serialized_buffer_index(
                         their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                         reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                        should be permuted such that the N and C dimensions are
-                        swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                        permuted and reshaped from (inc, oc/groups, height, width) to
+                        (oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
                         weights. This is only valid for tensors which hold
                         constant data. If used along with convert_to_nhwc, this
                         swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantize
             fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            groups: number of groups for swap_in_out_for_weights
 
         Returns:
             buffer_idx: idx of the serialized data. 0 If not associated constant
                         data
         """
+
+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid together with swap_in_out_for_weights"
+
         # The get_attr node is the input to quant_params.
         get_attr_node = tensor if quant_params is None else quant_params.q_input
         if not is_param_node(self.exported_program, get_attr_node):
             check_or_raise(
-                not swap_nc_for_depthwise_weights,
+                not swap_in_out_for_weights,
                 "Swapping N and C dimensions is only valid for constant data tensors",
             )
             return 0
@@ -560,12 +555,9 @@ def get_serialized_buffer_index(
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()
 
-        if swap_nc_for_depthwise_weights:
-            const_val = const_val.permute(
-                dims=((1, 0) + tuple(range(2, const_val.dim())))
-            ).contiguous()
-
-        if swap_in_out_for_transpose_weights:
+        if swap_in_out_for_weights:
+            # Permute and reshape the tensor from (inc, oc/groups, height, width) to
+            # (oc, inc/groups, height, width), which should be used for depthwise/transpose convolution weights for XNNPACK
             shape = const_val.shape
             const_val = const_val.reshape(
                 (groups, const_val.shape[0] // groups) + const_val.shape[1:]
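
Not part of the diff: a torch sketch, on illustrative shapes, of the full grouped mapping that the reshape above begins. The transpose/reshape steps after the first reshape are an assumption about the remaining (unshown) lines of the hunk, not a copy of them.

    import torch

    inc, oc, groups, kh, kw = 8, 16, 2, 3, 3
    w = torch.randn(inc, oc // groups, kh, kw)                   # conv_transpose2d weight layout

    w = w.reshape(groups, inc // groups, oc // groups, kh, kw)   # split in-channels by group
    w = w.transpose(1, 2)                                        # swap in/out within each group
    w = w.reshape(oc, inc // groups, kh, kw).contiguous()        # merge groups into out-channels

    assert w.shape == (oc, inc // groups, kh, kw)                # (16, 4, 3, 3)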