@@ -49,7 +49,7 @@
 
 import torch.utils._pytree as pytree
 from torch import SymInt, Tensor
-from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+from torch.fx.experimental.symbolic_shapes import guard_or_true
 
 
 if hasattr(torch.library, "register_fake"):
@@ -251,7 +251,7 @@ def tbe_input_combine_abstract(
         torch._check(index.is_contiguous())
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
-        if guard_size_oblivious(weight.numel() > 0):
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -288,7 +288,7 @@ def tbe_input_combine_with_length_abstract(
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
         total_offsets = total_offsets + offset.numel()
-        if guard_size_oblivious(weight.numel() > 0):
+        if guard_or_true(weight.numel() > 0):
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
@@ -807,7 +807,7 @@ def batch_index_select_dim0_forward_cpu_impl_abstract(
     torch._check(num_inputs == len(input_rows))
     torch._check(num_inputs == len(input_columns))
 
-    if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0):
+    if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
         # All num_indices must be the same if permute_output_dim_0_1 is True
         for x in input_num_indices:
             torch._check(x == input_num_indices[0])
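Context for the change: guard_or_true(expr) evaluates expr when it is statically decidable and otherwise returns True instead of raising a data-dependent guard error, so these fake kernels take the weight-checking branch optimistically whenever the size is an unbacked SymInt, and the torch._check() calls inside the branch record the constraints. A minimal sketch of the pattern, assuming a PyTorch build recent enough to ship guard_or_true (the helper check_optional_weight below is hypothetical, not FBGEMM code):

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_true


def check_optional_weight(weight: torch.Tensor, index: torch.Tensor) -> None:
    # Under fake-tensor tracing, weight.numel() may be an unbacked SymInt.
    # guard_or_true evaluates the predicate when it is statically decidable;
    # when it is data-dependent it returns True instead of raising, so the
    # branch is taken and the torch._check() calls become deferred runtime
    # asserts rather than trace-time failures.
    if guard_or_true(weight.numel() > 0):
        torch._check(weight.dim() == 1)
        torch._check(weight.numel() == index.numel())
        torch._check(weight.is_contiguous())


# With concrete (non-symbolic) tensors the predicate is a plain bool and
# guard_or_true simply passes it through:
check_optional_weight(torch.ones(4), torch.arange(4))   # checks run and pass
check_optional_weight(torch.empty(0), torch.arange(4))  # branch is skipped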