Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc
{%- endif %}

// short-circuit if there are zero indices.
if (TORCH_GUARD_SIZE_OBLIVIOUS(indices.sym_numel().sym_eq(0))) {
if (TORCH_GUARD_OR_FALSE(indices.sym_numel().sym_eq(0))) {
{%- if dense %}
return grad_dev_weights;
{%- elif optimizer == "none" %}
Expand Down Expand Up @@ -213,7 +213,7 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc

// Took allocation from https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/split_embeddings_utils.cu?lines=339-347
Tensor sorted_linear_indices_run;
if (TORCH_GUARD_SIZE_OBLIVIOUS(total_unique_indices.sym_gt(0))) {
if (TORCH_GUARD_OR_TRUE(total_unique_indices.sym_gt(0))) {
sorted_linear_indices_run = at::empty_symint({total_unique_indices}, indices.options());
} else {
sorted_linear_indices_run = at::empty_like(indices);
Expand Down
8 changes: 4 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

import torch.utils._pytree as pytree
from torch import SymInt, Tensor
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_true


if hasattr(torch.library, "register_fake"):
Expand Down Expand Up @@ -251,7 +251,7 @@ def tbe_input_combine_abstract(
torch._check(index.is_contiguous())
torch._check(offset.is_contiguous())
total_indices = total_indices + index.numel()
if guard_size_oblivious(weight.numel() > 0):
if guard_or_true(weight.numel() > 0):
torch._check(weight.dim() == 1)
torch._check(weight.numel() == index.numel())
torch._check(weight.is_contiguous())
Expand Down Expand Up @@ -288,7 +288,7 @@ def tbe_input_combine_with_length_abstract(
torch._check(offset.is_contiguous())
total_indices = total_indices + index.numel()
total_offsets = total_offsets + offset.numel()
if guard_size_oblivious(weight.numel() > 0):
if guard_or_true(weight.numel() > 0):
torch._check(weight.dim() == 1)
torch._check(weight.numel() == index.numel())
torch._check(weight.is_contiguous())
Expand Down Expand Up @@ -807,7 +807,7 @@ def batch_index_select_dim0_forward_cpu_impl_abstract(
torch._check(num_inputs == len(input_rows))
torch._check(num_inputs == len(input_columns))

if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0):
if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
# All num_indices must be the same if permute_output_dim_0_1 is True
for x in input_num_indices:
torch._check(x == input_num_indices[0])
Expand Down
Loading