@@ -243,16 +243,29 @@ def execute_node(self, context, graph):
243243 inp_values = context [node .input [0 ]]
244244 th_val = context [node .input [1 ]]
245245 out_bias = self .get_nodeattr ("ActVal" )
246- # MT expects inputs to be in the shape (N,C,H,W) or (N, C)
247- # if 4D then input values in context are (N,H,W,C) and need to
248- # be transposed.
249- # if 2D then inputs can be passed directly to MT function
250- is_4d = len (inp_values .shape ) == 4
251- if is_4d :
252- inp_values = np .transpose (inp_values , (0 , 3 , 1 , 2 ))
246+
247+ # Consider the data layout for transposing the input into the format
248+ # accepted by the multithreshold function above, i.e, the channel
249+ # dimension is along the axis with index 1.
250+ data_layout = None
251+ # If there is no layout annotation, guess based on rank of the tensor
252+ # TODO: Currently there is no mechanism here to get the layout
253+ # annotation, we allways guess, but this matches the previous behavior.
254+ if len (inp_values .shape ) < 5 :
255+ # Maps tensor rank to layout annotation
256+ rank_to_layout = {0 : None , 1 : "C" , 2 : "NC" , 3 : "NWC" , 4 : "NHWC" }
257+ # Lookup the layout required by this input shape
258+ data_layout = rank_to_layout [len (inp_values .shape )]
259+ # Lookup the index of the channel dimension in the data layout
260+ # Note: Assumes there is at most one "C" which denotes the channel
261+ # dimension
262+ cdim = data_layout .index ("C" ) if "C" in data_layout else 1
263+ # Rearrange the input to the expected (N, C, ...) layout
264+ inp_values = inp_values .swapaxes (cdim , 1 )
253265 y = multithreshold (inp_values , th_val , out_bias = out_bias )
254- if is_4d :
255- y = y .transpose (0 , 2 , 3 , 1 )
266+ # Rearrange the output back to the original layout
267+ y = y .swapaxes (cdim , 1 )
268+
256269 act = DataType [self .get_nodeattr ("outputDataType" )]
257270 if act == DataType ["BIPOLAR" ]:
258271 # binary to bipolar
0 commit comments