File tree Expand file tree Collapse file tree 2 files changed +15
-7
lines changed Expand file tree Collapse file tree 2 files changed +15
-7
lines changed Original file line number Diff line number Diff line change 11
11
12
12
import torch
13
13
14
- from torch .fx .experimental .symbolic_shapes import guard_size_oblivious
14
+ from torch .fx .experimental .symbolic_shapes import guard_or_false , guard_or_true
15
15
16
16
USE_TORCHDYNAMO_COMPILING_PATH : bool = False
17
17
@@ -91,8 +91,15 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
91
91
return x
92
92
93
93
94
- def pt2_guard_size_oblivious (x : bool ) -> bool :
94
+ def pt2_guard_or_false (x : bool ) -> bool :
95
95
if torch .jit .is_scripting () or not is_pt2_compiling ():
96
96
return x
97
97
98
- return guard_size_oblivious (x )
98
+ return guard_or_false (x )
99
+
100
+
101
+ def pt2_guard_or_true (x : bool ) -> bool :
102
+ if torch .jit .is_scripting () or not is_pt2_compiling ():
103
+ return x
104
+
105
+ return guard_or_true (x )
Original file line number Diff line number Diff line change 25
25
pt2_check_size_nonzero ,
26
26
pt2_checks_all_is_size ,
27
27
pt2_checks_tensor_slice ,
28
- pt2_guard_size_oblivious ,
28
+ pt2_guard_or_false ,
29
+ pt2_guard_or_true ,
29
30
)
30
31
from torchrec .streamable import Pipelineable
31
32
@@ -1071,7 +1072,7 @@ def _assert_tensor_has_no_elements_or_has_integers(
1071
1072
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
1072
1073
return
1073
1074
1074
- assert pt2_guard_size_oblivious (tensor .numel () == 0 ) or tensor .dtype in [
1075
+ assert pt2_guard_or_false (tensor .numel () == 0 ) or tensor .dtype in [
1075
1076
torch .long ,
1076
1077
torch .int ,
1077
1078
torch .short ,
@@ -1206,7 +1207,7 @@ def _maybe_compute_length_per_key(
1206
1207
torch .sum (
1207
1208
pt2_check_size_nonzero (lengths .view (len (keys ), stride )), dim = 1
1208
1209
).tolist ()
1209
- if pt2_guard_size_oblivious (lengths .numel () != 0 )
1210
+ if pt2_guard_or_true (lengths .numel () != 0 )
1210
1211
else [0 ] * len (keys )
1211
1212
)
1212
1213
)
@@ -1425,7 +1426,7 @@ def _maybe_compute_kjt_to_jt_dict(
1425
1426
torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
1426
1427
for lengths in split_lengths
1427
1428
]
1428
- elif pt2_guard_size_oblivious (lengths .numel () > 0 ):
1429
+ elif pt2_guard_or_true (lengths .numel () > 0 ):
1429
1430
strided_lengths = lengths .view (len (keys ), stride )
1430
1431
if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1431
1432
torch ._check (strided_lengths .size (0 ) > 0 )
You can’t perform that action at this time.
0 commit comments