Skip to content

Commit 19c25c4

Browse files
suvadeep89facebook-github-bot
authored andcommitted
Broadcast implementation in quantized_add (#10903)
Summary: Pull Request resolved: #10903 This implementation introduces broadcast support in quantized_add by doing the following steps: 1) Compute the output tensor's shape using `torch.broadcast_shapes` 2) Implement the stride and offset logic for indexing into X and Y when using broadcast 3) Adds new tests for broadcast in quantized_add Reviewed By: hsharma35 Differential Revision: D74773433
1 parent a63a648 commit 19c25c4

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,9 @@ def quantized_add_meta(
334334
out_scale: float,
335335
out_zero_point: int,
336336
) -> torch.Tensor:
337-
out_size = X.size()
338-
if list(X.size()) == [1]:
339-
out_size = Y.size()
340337

338+
# Determine output shape by broadcasting X and Y
339+
out_size = torch.broadcast_shapes(X.size(), Y.size())
341340
return X.new_empty(out_size, dtype=X.dtype)
342341

343342

@@ -352,10 +351,8 @@ def quantized_add_per_tensor_meta(
352351
out_scale: float,
353352
out_zero_point: int,
354353
) -> torch.Tensor:
355-
out_size = X.size()
356-
if list(X.size()) == [1]:
357-
out_size = Y.size()
358354

355+
out_size = torch.broadcast_shapes(X.size(), Y.size())
359356
return X.new_empty(out_size, dtype=X.dtype)
360357

361358

0 commit comments

Comments
 (0)