Skip to content

Commit c2b6d01

Browse files
emlinmeta-codesync[bot]
authored andcommitted
fix sparse_permute_1d kernel to support double dtype (#4969)
Summary: Pull Request resolved: #4969 X-link: https://github.com/facebookresearch/FBGEMM/pull/1986 Feature score collection will generate a weight tensor with fp64 dtype, we need to make sure the distribution kernels support this dtype Reviewed By: q10 Differential Revision: D83788055 fbshipit-source-id: ca8db60acf638413db839316ab27b3da6dc6785f
1 parent 3708b98 commit c2b6d01

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ std::tuple<Tensor, Tensor, std::optional<Tensor>> permute_1D_sparse_data_cpu(
943943
FBGEMM_DISPATCH_ALL_TYPES(
944944
indices.scalar_type(), "permute_1D_indices_weights_kernel_2", [&] {
945945
using indices_t = scalar_t;
946-
FBGEMM_DISPATCH_FLOAT_ONLY(
946+
FBGEMM_DISPATCH_FLOAT_AND_DOUBLE(
947947
weights.has_value() ? weights.value().scalar_type()
948948
: at::ScalarType::Float,
949949
"permute_1D_indices_weights_kernel_3",
@@ -2971,7 +2971,7 @@ std::tuple<Tensor, Tensor, std::optional<Tensor>> permute_sparse_features_cpu(
29712971
permuted_indices = at::empty(permuted_lengths_sum, indices.options());
29722972
AT_DISPATCH_INDEX_TYPES(
29732973
input_offsets.scalar_type(), "permute_data_kernel_1", ([&] {
2974-
FBGEMM_DISPATCH_FLOAT_ONLY(
2974+
FBGEMM_DISPATCH_FLOAT_AND_DOUBLE(
29752975
weights.has_value() ? weights.value().scalar_type()
29762976
: at::ScalarType::Float,
29772977
"permute_data_kernel_2",

fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ permute_1D_sparse_data_cuda(
141141
const auto weights_value_contig = weights_value.contiguous();
142142
permuted_weights =
143143
at::empty(permuted_indices_size, weights_value.options());
144-
FBGEMM_DISPATCH_ALL_TYPES(
144+
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
145145
weights_value.scalar_type(),
146146
"permute_1D_data_kernel_3",
147147
[&] {

0 commit comments

Comments
 (0)