14 | 14 | #include <ATen/native/xpu/sycl/Indexing.h>
15 | 15 | #include <ATen/native/xpu/sycl/IndexingUtils.h>
16 | 16 | #include <ATen/native/xpu/sycl/Loops.h>
| 17 | +#include <ATen/native/xpu/sycl/SortingKernels.h> |
17 | 18 | #include <ATen/native/xpu/sycl/pstl/PSTLFunctions.h>
18 | 19 | #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
19 | 20 | #include <ATen/ops/arange.h>
@@ -675,42 +676,66 @@ void index_put_deterministic_kernel(
675 | 676 |
676 | 677 | linearIndex.divide_(sliceSize, "trunc");
677 | 678 |
678 | | - sorted_indices.copy_(linearIndex); |
679 | | - pstl::itoa( |
680 | | - orig_indices.data_ptr<int64_t>(), |
681 | | - orig_indices.data_ptr<int64_t>() + linearIndex.numel(), |
682 | | - (int64_t)0); |
683 | | - pstl::sort<int64_t, int64_t>( |
| 679 | + auto range = at::arange(num_indices, linearIndex.options()); |
| 680 | + sort_pairs<int64_t, int64_t>( |
684 | 681 | linearIndex.const_data_ptr<int64_t>(),
685 | | - sorted_indices.data_ptr<int64_t>(), |
686 | | - orig_indices.data_ptr<int64_t>(), |
687 | | - linearIndex.numel(), |
| 682 | + sorted_indices.mutable_data_ptr<int64_t>(), |
| 683 | + range.const_data_ptr<int64_t>(), |
| 684 | + orig_indices.mutable_data_ptr<int64_t>(), |
| 685 | + num_indices, |
688 | 686 | false);
| 687 | + |
689 | 688 | TORCH_INTERNAL_ASSERT(
690 | 689 | linearIndex.numel() * sliceSize * nElemBefore == expandedValue.numel(),
691 | 690 | "number of flattened indices did not match number of elements in the value tensor: ",
692 | 691 | linearIndex.numel() * sliceSize * nElemBefore,
693 | 692 | " vs ",
694 | 693 | expandedValue.numel());
695 | | - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
696 | | - at::ScalarType::ComplexHalf, |
697 | | - at::ScalarType::BFloat16, |
698 | | - at::ScalarType::Half, |
699 | | - at::ScalarType::Bool, |
700 | | - expandedValue.scalar_type(), |
701 | | - "index_put_deterministic_kernel", |
702 | | - [&] { |
703 | | - launch_index_put_deterministic_kernel<scalar_t>( |
704 | | - sorted_indices.mutable_data_ptr<int64_t>(), |
705 | | - orig_indices.mutable_data_ptr<int64_t>(), |
706 | | - expandedValue.const_data_ptr<scalar_t>(), |
707 | | - src_.mutable_data_ptr<scalar_t>(), |
708 | | - num_indices, |
709 | | - sliceSize, |
710 | | - strideBefore, |
711 | | - nElemBefore, |
712 | | - accumulate); |
713 | | - }); |
| 694 | + |
| 695 | + if (sliceSize > SIMD) { |
| 696 | + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
| 697 | + at::ScalarType::ComplexHalf, |
| 698 | + at::ScalarType::BFloat16, |
| 699 | + at::ScalarType::Half, |
| 700 | + at::ScalarType::Bool, |
| 701 | + expandedValue.scalar_type(), |
| 702 | + "index_put_deterministic_kernel", |
| 703 | + [&] { |
| 704 | + launch_index_put_deterministic_kernel<scalar_t, scalar_t>( |
| 705 | + sorted_indices.mutable_data_ptr<int64_t>(), |
| 706 | + orig_indices.mutable_data_ptr<int64_t>(), |
| 707 | + expandedValue.const_data_ptr<scalar_t>(), |
| 708 | + src_.mutable_data_ptr<scalar_t>(), |
| 709 | + num_indices, |
| 710 | + sliceSize, |
| 711 | + strideBefore, |
| 712 | + nElemBefore, |
| 713 | + accumulate); |
| 714 | + }); |
| 715 | + } else { |
| 716 | + // Align acc type with CUDA |
| 717 | + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
| 718 | + at::ScalarType::ComplexHalf, |
| 719 | + at::ScalarType::BFloat16, |
| 720 | + at::ScalarType::Half, |
| 721 | + at::ScalarType::Bool, |
| 722 | + expandedValue.scalar_type(), |
| 723 | + "index_put_deterministic_kernel", |
| 724 | + [&] { |
| 725 | + using accscalar_t = at::opmath_type<scalar_t>; |
| 726 | + launch_index_put_deterministic_kernel<scalar_t, accscalar_t>( |
| 727 | + sorted_indices.mutable_data_ptr<int64_t>(), |
| 728 | + orig_indices.mutable_data_ptr<int64_t>(), |
| 729 | + expandedValue.const_data_ptr<scalar_t>(), |
| 730 | + src_.mutable_data_ptr<scalar_t>(), |
| 731 | + num_indices, |
| 732 | + sliceSize, |
| 733 | + strideBefore, |
| 734 | + nElemBefore, |
| 735 | + accumulate); |
| 736 | + }); |
| 737 | + } |
| 738 | + |
714 | 739 | if (permuted)
715 | 740 | self.copy_(src_.permute(inversePerm));
716 | 741 | else if (!self_contiguous) {
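Note on the hunk above: the removed pstl::itoa + pstl::sort pair is replaced by a single key-value sort_pairs call that sorts the flattened destination indices while carrying along an arange of original positions, and for small slices (sliceSize <= SIMD) the dispatch accumulates in accscalar_t = at::opmath_type<scalar_t> to align with CUDA. A minimal host-side sketch of the sort-then-accumulate idea (plain C++ with illustrative names, not the SYCL kernel) is:

```cpp
// Hypothetical host-side illustration (not the XPU kernel): sort linear
// indices while carrying their original positions, then accumulate values
// into the destination in sorted order so duplicate indices are combined
// in a fixed order. Accumulation happens in float, mirroring the
// opmath-style wider accumulation type used for small slices.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> linear_index = {2, 0, 2, 1};   // destination keys (may repeat)
  std::vector<float>   values       = {1.f, 2.f, 3.f, 4.f};
  std::vector<int64_t> orig_pos(linear_index.size());
  std::iota(orig_pos.begin(), orig_pos.end(), 0);     // 0, 1, 2, ...

  // Sort positions by key; a stable sort keeps equal keys in their original
  // order, which fixes the order in which duplicates are applied.
  std::stable_sort(orig_pos.begin(), orig_pos.end(),
                   [&](int64_t a, int64_t b) {
                     return linear_index[a] < linear_index[b];
                   });

  std::vector<float> dst(3, 0.f);
  for (int64_t p : orig_pos) {
    float acc = dst[linear_index[p]];  // accumulate in the wide type
    acc += values[p];
    dst[linear_index[p]] = acc;
  }
  for (float v : dst) std::cout << v << ' ';  // prints: 2 4 4
  std::cout << '\n';
}
```

Grouping duplicate destination indices via a stable sort fixes the order in which their contributions are applied, which is what keeps the accumulate path deterministic.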
@@ -1477,8 +1502,8 @@ void index_reduce_func_xpu_template(
1477 | 1502 | getTensorInfo<const index_t, unsigned int>(index);
1478 | 1503 | indexInfo.collapseDims();
1479 | 1504 |
1480 | | - // A reasonable choice for when to have each thread iterate over |
1481 | | - // index to choose |
| 1505 | + // A reasonable choice for when to have each thread iterate |
| 1506 | + // over index to choose |
1482 | 1507 | if (numIndex <= 16) {
1483 | 1508 | auto caller =
1484 | 1509 | SMALL_INDEX(scalar_t, index_t, unsigned int, func_t);
@@ -1707,8 +1732,8 @@ static inline ForwardIt find_bound(
1707 | 1732 | const T& value) {
1708 | 1733 | ForwardIt it;
1709 | 1734 | typename std::iterator_traits<ForwardIt>::difference_type count, step;
1710 | | - // NOTE: std::distance(first, last) compiles but produces wrong results here, |
1711 | | - // so only legacy random access iterators are safe in this code. |
| 1735 | + // NOTE: std::distance(first, last) compiles but produces wrong results |
| 1736 | + // here, so only legacy random access iterators are safe in this code. |
1712 | 1737 | count = last - first;
1713 | 1738 |
1714 | 1739 | while (count > 0) {
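For context, the loop that follows the rewrapped NOTE is the usual binary-search bound finder: it halves a count obtained by plain iterator subtraction rather than std::distance. A self-contained sketch of that loop shape (host-side, with illustrative names, not the actual find_bound template) is:

```cpp
// Standalone illustration of the bound-search loop shape referenced above:
// binary search by repeated halving, with the range length computed via
// pointer subtraction (count = last - first). 'lower' selects lower_bound
// versus upper_bound semantics. Illustrative sketch, not the XPU code.
#include <cstdint>
#include <iostream>

template <typename T>
const T* find_bound_sketch(const T* first, const T* last, const T& value, bool lower) {
  const T* it;
  int64_t count = last - first;  // range length via pointer arithmetic
  while (count > 0) {
    it = first;
    int64_t step = count / 2;
    it += step;                  // probe the middle of the range
    bool go_right = lower ? (*it < value) : !(value < *it);
    if (go_right) {
      first = ++it;              // discard the left half and the probe
      count -= step + 1;
    } else {
      count = step;              // keep searching the left half
    }
  }
  return first;
}

int main() {
  const int data[] = {1, 2, 2, 2, 5, 7};
  const int* lo = find_bound_sketch(data, data + 6, 2, /*lower=*/true);
  const int* hi = find_bound_sketch(data, data + 6, 2, /*lower=*/false);
  std::cout << (lo - data) << ' ' << (hi - data) << '\n';  // prints: 1 4
}
```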