File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed
Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -334,10 +334,9 @@ def quantized_add_meta(
334334 out_scale : float ,
335335 out_zero_point : int ,
336336) -> torch .Tensor :
337- out_size = X .size ()
338- if list (X .size ()) == [1 ]:
339- out_size = Y .size ()
340337
338+ # Determine output shape by broadcasting X and Y
339+ out_size = torch .broadcast_shapes (X .size (), Y .size ())
341340 return X .new_empty (out_size , dtype = X .dtype )
342341
343342
@@ -352,10 +351,8 @@ def quantized_add_per_tensor_meta(
352351 out_scale : float ,
353352 out_zero_point : int ,
354353) -> torch .Tensor :
355- out_size = X .size ()
356- if list (X .size ()) == [1 ]:
357- out_size = Y .size ()
358354
355+ out_size = torch .broadcast_shapes (X .size (), Y .size ())
359356 return X .new_empty (out_size , dtype = X .dtype )
360357
361358
You can’t perform that action at this time.
0 commit comments