Skip to content

Commit 5b9f040

Browse files
pytorchbotangelayi
andauthored
Symintify fused_scaled_matmul_reduce_scatter (pytorch#167122)
Symintify fused_scaled_matmul_reduce_scatter (pytorch#165086) Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#165086 Approved by: https://github.com/zou3519, https://github.com/Skylion007 (cherry picked from commit 4a0df39) Co-authored-by: angelayi <[email protected]>
1 parent 49046e0 commit 5b9f040

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
450450
lib.define(
451451
"fused_scaled_matmul_reduce_scatter("
452452
"Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, "
453-
"str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, int[]? output_shape, "
453+
"str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, SymInt[]? output_shape, "
454454
"Tensor? bias = None, "
455455
"Tensor? result_scale = None, "
456456
"ScalarType? out_dtype = None, "

0 commit comments

Comments
 (0)