@@ -8,14 +8,14 @@ def __convert_max_elementwise(ctx):
     input_b = ctx.method_args[1]
     output = ctx.method_return
     input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
-    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], output.ndim - 1)
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
     layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.MAX)
     output._trt = layer.get_output(0)
 
 
 def __convert_max_reduce(ctx):
     input = ctx.method_args[0]
-    dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim)))
+    dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, len(input.shape))))
     keepdim = get_arg(ctx, 'keepdim', pos=2, default=False)
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output_val = ctx.method_return[0]
@@ -59,4 +59,4 @@ def forward(self, x, y):
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3), (1,)])  # broadcast
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3), (1, 3, 3)])  # broadcast
 def test_max_elementwise():
-    return MaxElementwise()
+    return MaxElementwise()
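
The change swaps `tensor.ndim` for `len(tensor.shape)` in both converters. A minimal sketch of why the two are interchangeable for regular `torch.Tensor` inputs, while `len(.shape)` also covers objects that expose `.shape` but not an `.ndim` attribute (an assumption about the motivation, not stated in the patch):

# Sketch only, not part of the patch.
import torch

x = torch.zeros(1, 3, 3, 3)

# For a torch.Tensor the two expressions agree.
assert x.ndim == len(x.shape) == 4

# The converters pass "len(output.shape) - 1" so the implicit batch
# dimension is excluded when broadcasting the TensorRT inputs.
print(len(x.shape) - 1)  # 3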