Commit 33e61aa

laithsakka authored and meta-codesync[bot] committed
use guard_or_true instead of guard_size_oblivious in sparse_ops.py (#4974)
Summary:
Pull Request resolved: #4974
X-link: https://github.com/facebookresearch/FBGEMM/pull/1991

This keeps the same intended semantics as guard_size_oblivious here: if the condition cannot be determined, assume it is True. guard_size_oblivious will be deprecated.

Reviewed By: q10
Differential Revision: D83884674
fbshipit-source-id: 0bd59c5502692702c87448c226153bd2dba6a044
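As a rough sketch of the fallback behavior described above (a minimal example, assuming a recent PyTorch build that exports guard_or_true from torch.fx.experimental.symbolic_shapes; the helper has_elements is hypothetical):

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_true

def has_elements(t: torch.Tensor) -> bool:
    # On eager tensors the comparison is a plain bool and passes through.
    # Under tracing, if t.numel() is an unbacked (data-dependent) symbol
    # and the comparison cannot be resolved, guard_or_true returns True
    # instead of raising a data-dependent guard error.
    return guard_or_true(t.numel() > 0)

print(has_elements(torch.ones(3)))   # True
print(has_elements(torch.empty(0)))  # False: the size here is concrete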
1 parent 44a6cbf commit 33e61aa

File tree

1 file changed: +4 −4 lines changed


fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,7 @@
 
 import torch.utils._pytree as pytree
 from torch import SymInt, Tensor
-from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+from torch.fx.experimental.symbolic_shapes import guard_or_true
 
 
 if hasattr(torch.library, "register_fake"):
@@ -251,7 +251,7 @@ def tbe_input_combine_abstract(
         torch._check(index.is_contiguous())
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
-        if guard_size_oblivious(weight.numel() > 0):
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -288,7 +288,7 @@ def tbe_input_combine_with_length_abstract(
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
         total_offsets = total_offsets + offset.numel()
-        if guard_size_oblivious(weight.numel() > 0):
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -807,7 +807,7 @@ def batch_index_select_dim0_forward_cpu_impl_abstract(
     torch._check(num_inputs == len(input_rows))
     torch._check(num_inputs == len(input_columns))
 
-    if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0):
+    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
         # All num_indices must be the same if permute_output_dim_0_1 is True
         for x in input_num_indices:
             torch._check(x == input_num_indices[0])
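The hunks above all touch fake (meta) kernels, where these shape checks run during tracing rather than on real data. A minimal sketch of the same pattern, assuming a PyTorch build with torch.library.register_fake; the op name mylib::combine and its schema are hypothetical, for illustration only:

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_true

# Hypothetical custom op, for illustration only.
torch.library.define("mylib::combine", "(Tensor index, Tensor weight) -> Tensor")

@torch.library.register_fake("mylib::combine")
def combine_abstract(index: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # If weight.numel() is an unbacked symbol, the branch is assumed taken,
    # so the shape checks below are still recorded at trace time.
    if guard_or_true(weight.numel() > 0):
        torch._check(weight.dim() == 1)
        torch._check(weight.numel() == index.numel())
    return index.new_empty(index.numel())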
