@@ -914,7 +914,7 @@ def generate_conv2d_line_buffer_fn(
914914 return generated_code
915915
916916 @staticmethod
917- def permute_config_gen (name : str , shape : tuple [int , ...], perm : tuple [int , ...]):
917+ def transpose_config_gen (name : str , shape : tuple [int , ...], perm : tuple [int , ...]):
918918 """
919919 Generate new shape and perm_strides for a permute operation. Operates by mapping the output index
920920 to input input index by:
@@ -933,12 +933,20 @@ def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...])
933933 perm (tuple[int, ...]): The permutation of the dimensions.
934934
935935 Returns:
936- (new_shape, perm_strides) (tuple, tuple): the output shape and permutation strides .
936+ dict: Dictionary containing the configuration .
937937 """
938938 new_shape = tuple (shape [i ] for i in perm )
939939 strides = np .cumprod ((shape [1 :] + (1 ,))[::- 1 ])[::- 1 ]
940940 perm_strides = tuple (int (strides [i ]) for i in perm )
941- return (new_shape , perm_strides )
941+ return dict (
942+ dims = len (shape ),
943+ N = math .prod (shape ),
944+ from_shape = ', ' .join (str (x ) for x in shape ),
945+ perm = ', ' .join (str (x ) for x in perm ),
946+ perm_strides = ', ' .join (str (x ) for x in perm_strides ),
947+ to_shape = ', ' .join (str (x ) for x in new_shape ),
948+ config_name = name ,
949+ )
942950
943951 @model_optimizer ()
944952 def write_hls (self , model ):
0 commit comments