diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu index f60ec77637..fec498b3b8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu @@ -133,7 +133,7 @@ permute_1D_sparse_data_cuda( AT_DISPATCH_INDEX_TYPES( input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] { using offsets_t = index_t; - FBGEMM_DISPATCH_ALL_TYPES( + FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE( indices.scalar_type(), "permute_1D_data_kernel_2", [&] { using indices_t = scalar_t; if (weights.has_value()) { @@ -141,7 +141,7 @@ permute_1D_sparse_data_cuda( const auto weights_value_contig = weights_value.contiguous(); permuted_weights = at::empty(permuted_indices_size, weights_value.options()); - FBGEMM_DISPATCH_ALL_TYPES( + FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE( weights_value.scalar_type(), "permute_1D_data_kernel_3", [&] {