@@ -100,7 +100,7 @@ def get_nodeattr_types(self):
100100 "out_dtype" : ("s" , True , "" ),
101101 "out_scale" : ("f" , False , 1.0 ),
102102 "out_bias" : ("f" , False , 0.0 ),
103- "data_layout" : ("s" , False , "NCHW " ),
103+ "data_layout" : ("s" , False , "" ),
104104 }
105105
106106 def make_shape_compatible_op (self , model ):
@@ -130,12 +130,32 @@ def execute_node(self, context, graph):
130130 # retrieve attributes if output scaling is used
131131 out_scale = self .get_nodeattr ("out_scale" )
132132 out_bias = self .get_nodeattr ("out_bias" )
133- # transpose input if NHWC data layout is chosen
133+
134+ # Consider the data layout for transposing the input into the format
135+ # accepted by the multithreshold function above, i.e, the channel
136+ # dimension is along the axis with index 1.
134137 data_layout = self .get_nodeattr ("data_layout" )
135- channels_last = True if data_layout [- 1 ] == "C" else False
136- # calculate output
138+ # If there is no layout annotation, guess based on rank of the
139+ # tensor
140+ if not data_layout and len (v .shape ) < 5 :
141+ # Maps tensor rank to layout annotation
142+ rank_to_layout = {0 : None , 1 : None , 2 : "NC" , 3 : "NWC" , 4 : "NCHW" }
143+ # Lookup the layout required by this input shape
144+ data_layout = rank_to_layout [len (v .shape )]
145+ # Lookup the index of the channel dimension in the data layout
146+ # Note: Assumes there is at most one "C" which denotes the channel
147+ # dimension
148+ if data_layout is not None :
149+ cdim = data_layout .index ("C" ) if "C" in data_layout else 1
150+ else :
151+ cdim = 1
152+ # Rearrange the input to the expected (N, C, ...) layout
137153 orig_shape = v .shape
138- output = multithreshold (v , thresholds , out_scale , out_bias , channels_last )
154+ v = v .swapaxes (cdim , 1 )
155+ # Now we can use the multithreshold function to calculate output
156+ output = multithreshold (v , thresholds , out_scale , out_bias )
157+ # Rearrange the output back to the original layout
158+ output = output .swapaxes (cdim , 1 )
139159 assert output .shape == orig_shape , "Shape changed during thresholding!"
140160 context [node .output [0 ]] = output
141161
0 commit comments