Skip to content

Commit eaea63d

Browse files
laithsakka authored and meta-codesync[bot] committed
remove guard_size_oblivious from torchrec jagged tensors. (#3431)
Summary: Pull Request resolved: #3431 keep intended semantics but use guard_or invariants. guard_size_oblivious will be deprecated soon. Reviewed By: TroyGarden Differential Revision: D83885644 fbshipit-source-id: dfcabef88e59f0a8c521f70d152070b344e3a237
1 parent 22406c2 commit eaea63d

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

torchrec/pt2/checks.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import torch
1313

14-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
14+
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
1515

1616
USE_TORCHDYNAMO_COMPILING_PATH: bool = False
1717

@@ -91,8 +91,15 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
9191
return x
9292

9393

94-
def pt2_guard_size_oblivious(x: bool) -> bool:
94+
def pt2_guard_or_false(x: bool) -> bool:
9595
if torch.jit.is_scripting() or not is_pt2_compiling():
9696
return x
9797

98-
return guard_size_oblivious(x)
98+
return guard_or_false(x)
99+
100+
101+
def pt2_guard_or_true(x: bool) -> bool:
102+
if torch.jit.is_scripting() or not is_pt2_compiling():
103+
return x
104+
105+
return guard_or_true(x)

torchrec/sparse/jagged_tensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
pt2_check_size_nonzero,
2626
pt2_checks_all_is_size,
2727
pt2_checks_tensor_slice,
28-
pt2_guard_size_oblivious,
28+
pt2_guard_or_false,
29+
pt2_guard_or_true,
2930
)
3031
from torchrec.streamable import Pipelineable
3132

@@ -1071,7 +1072,7 @@ def _assert_tensor_has_no_elements_or_has_integers(
10711072
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
10721073
return
10731074

1074-
assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [
1075+
assert pt2_guard_or_false(tensor.numel() == 0) or tensor.dtype in [
10751076
torch.long,
10761077
torch.int,
10771078
torch.short,
@@ -1206,7 +1207,7 @@ def _maybe_compute_length_per_key(
12061207
torch.sum(
12071208
pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1
12081209
).tolist()
1209-
if pt2_guard_size_oblivious(lengths.numel() != 0)
1210+
if pt2_guard_or_true(lengths.numel() != 0)
12101211
else [0] * len(keys)
12111212
)
12121213
)
@@ -1425,7 +1426,7 @@ def _maybe_compute_kjt_to_jt_dict(
14251426
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
14261427
for lengths in split_lengths
14271428
]
1428-
elif pt2_guard_size_oblivious(lengths.numel() > 0):
1429+
elif pt2_guard_or_true(lengths.numel() > 0):
14291430
strided_lengths = lengths.view(len(keys), stride)
14301431
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
14311432
torch._check(strided_lengths.size(0) > 0)

0 commit comments

Comments (0)