diff --git a/src/ATen/native/xpu/sycl/Shape.cpp b/src/ATen/native/xpu/sycl/Shape.cpp
index 12bd0ba66d..7119568085 100644
--- a/src/ATen/native/xpu/sycl/Shape.cpp
+++ b/src/ATen/native/xpu/sycl/Shape.cpp
@@ -1,34 +1,89 @@
+#include
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
+#include
+#include
 #include
+#include
 #include
 #include
+#include
+#include
 #include
 #include
+#include
 #include
 
 namespace at::native::xpu {
 
-// The best performance is achieved for parallel computing with 1024 batch sizes
-// at a time.
-constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
+// Metadata is now passed to the kernels by value (see
+// CatArrInputTensorMetadata below), so the batch size is bounded by the
+// kernel argument size rather than by pinned-memory staging.
+constexpr int CAT_ARRAY_BATCH_SIZE = 64;
+constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;
+constexpr int ALIGNED_VEC_LOAD_BYTES_16 = 16;
+constexpr int ALIGNED_VEC_LOAD_BYTES_8 = 8;
 
-// Maximum parallel dimension to supporte
-constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 5;
+inline bool is_aligned_vec4(const void* ptr) {
+  auto iptr = reinterpret_cast<uintptr_t>(ptr);
+  return !(iptr % alignof(uint4));
+}
+
+inline std::tuple<sycl::range<2>, sycl::range<2>> getCatRange(
+    unsigned int max_elements_per_tensor,
+    ptrdiff_t nTensors) {
+  constexpr unsigned int items_per_group = 256;
+  constexpr unsigned int elements_per_item = 8;
+  constexpr unsigned int max_group_per_eu = 32;
+
+  unsigned int max_items = ceil_div(max_elements_per_tensor, elements_per_item);
+  unsigned int item_groups = ceil_div(max_items, items_per_group);
+
+  const unsigned int num_eu = syclGpuEUCountPerSubslice();
+  item_groups = std::min(num_eu * max_group_per_eu, item_groups);
+
+  sycl::range<2> global_range(
+      (long long)nTensors, items_per_group * item_groups);
+  sycl::range<2> local_range(1, items_per_group);
+  return std::make_tuple(global_range, local_range);
+}
+
+template <typename T, int aligned_vec_load_bytes>
+inline std::tuple<sycl::range<2>, sycl::range<2>> getCatRangeContig(
+    unsigned int max_elements_per_tensor,
+    ptrdiff_t nTensors) {
+  constexpr unsigned int items_per_group = 256;
+  constexpr unsigned int min_aligned_vec_per_item = 1;
+  constexpr unsigned int max_group_per_eu = 32;
+
+  unsigned int elements_per_item =
+      aligned_vec_load_bytes / sizeof(T) * min_aligned_vec_per_item;
+  unsigned int max_items = ceil_div(max_elements_per_tensor, elements_per_item);
+  unsigned int item_groups = ceil_div(max_items, items_per_group);
+
+  const unsigned int num_eu = syclGpuEUCountPerSubslice();
+  item_groups = std::min(num_eu * max_group_per_eu, item_groups);
+
+  sycl::range<2> global_range(
+      (long long)nTensors, item_groups * items_per_group);
+  sycl::range<2> local_range(1, items_per_group);
+  return std::make_tuple(global_range, local_range);
+}
 
 // Similar to any other IndexToOffset calculation for copying along a given
 // dimension.
 template <typename IndexType, int Dims>
 struct CatArrIndexToOffset {
   static inline IndexType compute(
-      const IndexType outputSize[Dims],
-      const IndexType outputStride[Dims],
+      const IndexType tensorSize[Dims],
+      const IndexType tensorStride[Dims],
       const IndexType dimSize,
       const unsigned int concatDim,
       IndexType linearIndex) {
@@ -40,227 +95,422 @@ struct CatArrIndexToOffset {
 
 #pragma unroll
     for (int i = Dims - 1; i >= 1; --i) {
-      IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
+      IndexType curDimSize = i == concatDim ?
dimSize : tensorSize[i]; IndexType nextDimIndex = linearIndex / curDimSize; IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex; - IndexType curDimOffset = curDimIndex * outputStride[i]; + IndexType curDimOffset = curDimIndex * tensorStride[i]; offset += curDimOffset; linearIndex = nextDimIndex; } - return offset + linearIndex * outputStride[0]; + return offset + linearIndex * tensorStride[0]; } }; -template -struct CatArrInputTensor { - T* input; - IndexType offset; - IndexType dimSize; - IndexType nElements; +template +struct TensorSizeStride { + IndexType tensorSize[MaxDims]; + IndexType tensorStride[MaxDims]; }; -template -struct OutputTensorSizeStride { - IndexType outputSize[MaxDims]; - IndexType outputStride[MaxDims]; +// pass meta data directly through kernel argument instead of pin memory +// In contiguous case, we will not need stride_size, setting it as 1 as +// placeholder to pass compile. +template +struct CatArrInputTensorMetadata { + const T* input[n]; + IndexType offset[n]; + IndexType dimSize[n]; + IndexType nElements[n]; + bool isContiguous[n]; + TensorSizeStride + tensorStride[stride_size]; }; template < - typename Tout, - typename underlying_out_t, - typename Tin, - typename underlying_in_t, + typename T, typename IndexType, - int Dims> -struct CatArrayBatchedCopyKernelFunctor { + int Dims, + int batch_size, + int stride_size> +struct CatArrayBatchedCopy { void operator()(sycl::nd_item<2> item) const { - IndexType tid = item.get_global_id(1); - IndexType in = item.get_group(0); + IndexType tid = + item.get_group(1) * item.get_local_range(1) + item.get_local_id(1); + IndexType nElements = inputs.nElements[item.get_group(0)]; + TensorSizeStride ins = stride_size > 1 + ? inputs.tensorStride[item.get_group(0)] + : inputs.tensorStride[0]; + bool isContig = inputs.isContiguous[item.get_group(0)]; - IndexType nElements = inputs[in].nElements; + if (tid >= nElements) + return; + + const T* data = inputs.input[item.get_group(0)]; + IndexType offset = inputs.offset[item.get_group(0)]; + IndexType dimSize = inputs.dimSize[item.get_group(0)]; + IndexType dataOffset = offset * dimStride; + + IndexType stride = item.get_group_range(1) * item.get_local_range(1); + + while (tid < nElements) { + IndexType elementOffset = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); + if (isContig) { + output[dataOffset + elementOffset] = data[tid]; + } else { + IndexType inElementOffset = + CatArrIndexToOffset::compute( + ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[inElementOffset]; + } + tid += stride; + } + } + + CatArrayBatchedCopy( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) + : output(output), + inputs(inputs), + os(os), + concatDim(concatDim), + dimStride(dimStride) {} + + private: + T* output; + CatArrInputTensorMetadata inputs; + TensorSizeStride os; + const int concatDim; + IndexType dimStride; +}; + +template < + typename T, + typename IndexType, + int Dims, + int batch_size, + int stride_size> +struct CatArrayBatchedCopy_contig { + void operator()(sycl::nd_item<2> item) const { + IndexType tid = + item.get_group(1) * item.get_local_range(1) + item.get_local_id(1); + IndexType nElements = inputs.nElements[item.get_group(0)]; if (tid >= nElements) return; - Tin* data = inputs[in].input; - IndexType offset = inputs[in].offset; - IndexType dimSize = inputs[in].dimSize; + const T* data = 
inputs.input[item.get_group(0)]; + IndexType offset = inputs.offset[item.get_group(0)]; + IndexType dimSize = inputs.dimSize[item.get_group(0)]; IndexType dataOffset = offset * dimStride; - IndexType stride = item.get_global_range(1); + IndexType stride = item.get_group_range(1) * item.get_local_range(1); while (tid < nElements) { IndexType elementOffset = CatArrIndexToOffset::compute( - os.outputSize, os.outputStride, dimSize, concatDim, tid); + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); output[dataOffset + elementOffset] = data[tid]; tid += stride; } } - CatArrayBatchedCopyKernelFunctor( - Tout* output_, - CatArrInputTensor* inputs_, - OutputTensorSizeStride os_, - const int concatDim_, - IndexType dimStride_) - : output(output_), - inputs(inputs_), - os(os_), - concatDim(concatDim_), - dimStride(dimStride_) {} + CatArrayBatchedCopy_contig( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) + : output(output), + inputs(inputs), + os(os), + concatDim(concatDim), + dimStride(dimStride) {} private: - Tout* output; - CatArrInputTensor* inputs; - OutputTensorSizeStride os; + T* output; + CatArrInputTensorMetadata inputs; + TensorSizeStride os; const int concatDim; IndexType dimStride; }; -/** - * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a - * grid-stride loop based off of the blockIdx.x, threadIdx.x for each input to - * copy each element from each input tensor into the output. - * - * output: base pointer to the storage associated with the output tensor - * inputs: GPU-allocated array of input metadata for each input to concatenate - * in the kernel - * os: the size/stride vectors for the output tensor - * concatDim: dimension along which we are concatenating - * dimStride: the stride of the output tensor at the concatDim - * - * The most important assumption made is that the input tensors are contiguous. - */ +/* + Specialized implementation of the CatArrayBatchedCopy written to generate wide + memory loads to improve memory bandwidth throughput. +*/ + template < - typename Tout, - typename underlying_out_t, - typename Tin, - typename underlying_in_t, + typename T, typename IndexType, - int Dims> -void CatArrayBatchedCopy( - Tout* output, - CatArrInputTensor* inputs, - OutputTensorSizeStride os, - const int concatDim, - IndexType dimStride, - int batchCounter) { - CatArrayBatchedCopyKernelFunctor< - Tout, - underlying_out_t, - Tin, - underlying_in_t, - IndexType, - Dims> - kfn(output, inputs, os, concatDim, dimStride); - - // Get grid where x dim fills half gpu and y dim is number of tensors. - // This will have cating two tensors fill the entire grid, but prevent - // many threads from needlessly load meta data if their sizes is small. - int64_t numWI = syclMaxWorkGroupSize(kfn); - - // We set limited numWG to prevent over schedule. - // numWG = 512 EUs * 8 threads * SIMD lanes 32 / max_compute_units - // (1024 on PVC). - // When input tensors less than 32, we choose 128 numWG to handle a tensor, - // then we have one tile per tensor. - // When input tensors more than 32, we choose 64 numWG to handle a tensor, - // half tile per tensor, the other half is occupied by next input tensor. 
- int64_t numWG; - if (batchCounter > 32) - numWG = 64; - else - numWG = 128; - sycl::range<2> global_range(batchCounter, numWG * numWI); - sycl::range<2> local_range(1, numWI); - auto& q = getCurrentSYCLQueue(); - - sycl_kernel_submit(global_range, local_range, q, kfn); -} + int Dims, + int batch_size, + int stride_size, + int aligned_vec_load_bytes> +struct CatArrayBatchedCopy_alignedK_contig { + void operator()(sycl::nd_item<2> item) const { + // This kernel tries to use aligned_vec_load_bytes*8 bit loads + // Special case 2-byte types to use 8-byte vec loads to reduce register + // pressure The below lambda is to allow cc compiler to pass kILP>0 checks + // for large types (e.g. ComplexDouble, 16 bytes) + constexpr int kILP = aligned_vec_load_bytes / sizeof(T) > 0 + ? aligned_vec_load_bytes / sizeof(T) + : ALIGNED_VEC_LOAD_BYTES_16 / sizeof(T); + + IndexType inputOffset = + (item.get_group(1) * item.get_local_range(1) + item.get_local_id(1)) * + kILP; + IndexType inputStride = + item.get_group_range(1) * item.get_local_range(1) * kILP; + + IndexType nElements = inputs.nElements[item.get_group(0)]; + if (inputOffset >= nElements) { + return; + } -template < - typename scalar_out_t, - typename underlying_out_t, - typename scalar_in_t, - typename underlying_in_t> + const T* data = inputs.input[item.get_group(0)]; + IndexType offset = inputs.offset[item.get_group(0)]; + IndexType dimSize = inputs.dimSize[item.get_group(0)]; + IndexType dataOffset = offset * dimStride; + + IndexType v_elementOffset[kILP]; + T reg_data[kILP]; + + while (inputOffset + kILP <= nElements) { + for (int i = 0; i < kILP; ++i) { + v_elementOffset[i] = CatArrIndexToOffset::compute( + os.tensorSize, + os.tensorStride, + dimSize, + concatDim, + inputOffset + i); + } + + using LT = memory::aligned_vector; + ((LT*)reg_data)[0] = const_cast((LT*)(data + inputOffset))[0]; + +#pragma unroll + for (int i = 0; i < kILP; ++i) { + output[dataOffset + v_elementOffset[i]] = reg_data[i]; + } + + inputOffset += inputStride; + } + + // Handle remaining tail in case nElements does not divide + // exactly to kILP + + while (inputOffset < nElements) { + v_elementOffset[0] = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, inputOffset); + output[dataOffset + v_elementOffset[0]] = data[inputOffset]; + inputOffset++; + } + } + + CatArrayBatchedCopy_alignedK_contig( + T* output, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) + : output(output), + inputs(inputs), + os(os), + concatDim(concatDim), + dimStride(dimStride) {} + + private: + T* output; + CatArrInputTensorMetadata inputs; + TensorSizeStride os; + const int concatDim; + IndexType dimStride; +}; + +template void parallel_cat( const Tensor& out, const MaterializedITensorListRef& inputs, int64_t dimension, - int nDims) { + int nDims, + c10::MemoryFormat memory_format) { // First, let's set up our kernel parameters. We start with a raw pointer to // the storage for the output Tensor. 
- scalar_out_t* data = static_cast(out.mutable_data_ptr()); - - // Kernel Parameter - int64_t tensorMetadataSize = - sizeof(CatArrInputTensor) * - CAT_ARRAY_BATCH_SIZE; - auto d_inputs_storage = - at::empty({tensorMetadataSize}, out.options().dtype(at::kByte)); - auto d_inputs = static_cast*>( - d_inputs_storage.mutable_data_ptr()); - - OutputTensorSizeStride param; - - for (int i = 0; i < nDims; ++i) { - param.outputSize[i] = at::native::size(out, i); - param.outputStride[i] = out.stride(i); + scalar_t* data = (scalar_t*)(out.mutable_data_ptr()); + CatArrInputTensorMetadata + catMetaData; + TensorSizeStride outputParam; + + // Next, let's initialize the size, stride arrays for the output Tensor. + if (memory_format == c10::MemoryFormat::Contiguous) { + for (int i = 0; i < nDims; ++i) { + outputParam.tensorSize[i] = out.size(i); + outputParam.tensorStride[i] = out.stride(i); + } + } else if ( + memory_format == c10::MemoryFormat::ChannelsLast || + memory_format == c10::MemoryFormat::ChannelsLast3d) { + // permute the semantics of dims from NCHW to NHWC so that the input + // tensor is now contiguous + outputParam.tensorSize[0] = out.size(0); + outputParam.tensorStride[0] = out.stride(0); + for (int i = 1; i < nDims - 1; ++i) { + outputParam.tensorSize[i] = out.size(i + 1); + outputParam.tensorStride[i] = out.stride(i + 1); + } + outputParam.tensorSize[nDims - 1] = out.size(1); + outputParam.tensorStride[nDims - 1] = out.stride(1); + } else { + TORCH_CHECK(false, "unsupported memory format"); } + // If all batches are contiguous we can call a specialized implementation + // which requires the input tensor addresses to be aligned to a + // 16 Byte boundary. + + bool isContig = true; + bool isAligned = true; + unsigned int max_elements_per_tensor = 0; + // Now we loop - auto& q = getCurrentSYCLQueue(); int batchCounter = 0; int64_t offset = 0; - for (int i = 0; i < inputs.size(); i += CAT_ARRAY_BATCH_SIZE) { - // Re-allocate stackInputs every iteration to avoid read-after-write hazard - { - CatArrInputTensor* stackInputs; - - auto stackInputs_dptr = - at::getHostAllocator(at::kXPU)->allocate(tensorMetadataSize); - stackInputs = - (CatArrInputTensor*)stackInputs_dptr.get(); - - for (batchCounter = 0; batchCounter < CAT_ARRAY_BATCH_SIZE && - (i + batchCounter) < inputs.size(); - ++batchCounter) { - int64_t dimSize = - at::native::size(inputs[i + batchCounter].get(), dimension); - - stackInputs[batchCounter].input = - (scalar_in_t*)(inputs[i + batchCounter].get().const_data_ptr()); - stackInputs[batchCounter].offset = offset; - stackInputs[batchCounter].dimSize = dimSize; - stackInputs[batchCounter].nElements = - inputs[i + batchCounter].get().numel(); - - // update offset - offset += dimSize; + for (unsigned i = 0; i < inputs.size(); i += batch_size) { + for (batchCounter = 0; + batchCounter < batch_size && (i + batchCounter) < inputs.size(); + ++batchCounter) { + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i + batchCounter].get().numel() > 0) { + dimSize = inputs[i + batchCounter].get().size(dimension); } - q.memcpy((void*)d_inputs, (void*)stackInputs, tensorMetadataSize); - at::getHostAllocator(at::kXPU)->record_event( - (void*)stackInputs, - stackInputs_dptr.get_context(), - at::xpu::getCurrentXPUStream()); + catMetaData.input[batchCounter] = + (scalar_t*)(inputs[i + batchCounter].get().const_data_ptr()); + catMetaData.offset[batchCounter] = offset; + catMetaData.dimSize[batchCounter] = dimSize; + 
      catMetaData.nElements[batchCounter] =
+          inputs[i + batchCounter].get().numel();
+
+      // If at least one of the inputs is not aligned, we can't use the
+      // CatArrayBatchedCopy_alignedK_contig kernel.
+      isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
+
+      if (stride_size > 1) {
+        auto strides = inputs[i + batchCounter].get().strides();
+        auto sizes = inputs[i + batchCounter].get().sizes();
+        for (int j = 0; j < nDims; j++) {
+          catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j];
+          catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
+        }
+        catMetaData.isContiguous[batchCounter] = false;
+        isContig = false;
+      } else {
+        catMetaData.isContiguous[batchCounter] = true;
+      }
+
+      // Update offset
+      offset += dimSize;
+
+      // We need max elements per tensor to compute range parameters
+      max_elements_per_tensor = std::max(
+          max_elements_per_tensor, catMetaData.nElements[batchCounter]);
     }
-#define HANDLE_CASE(DIMS)            \
-  CatArrayBatchedCopy<               \
-      scalar_out_t,                  \
-      underlying_out_t,              \
-      scalar_in_t,                   \
-      underlying_in_t,               \
-      unsigned int,                  \
-      DIMS>(                         \
-      data,                          \
-      d_inputs,                      \
-      param,                         \
-      dimension,                     \
-      param.outputStride[dimension], \
-      batchCounter);
+    // Skip the launch if the tensors seen so far are all empty. Otherwise,
+    // the computed range would be invalid.
+    if (max_elements_per_tensor == 0)
+      continue;
+
+    sycl::range<2> applyGroup, catRange;
+    if (isContig && sizeof(scalar_t) > 2) {
+      std::tie(catRange, applyGroup) =
+          getCatRangeContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_16>(
+              max_elements_per_tensor, batchCounter);
+    } else if (isContig && sizeof(scalar_t) == 2) {
+      std::tie(catRange, applyGroup) =
+          getCatRangeContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_8>(
+              max_elements_per_tensor, batchCounter);
+    } else {
+      std::tie(catRange, applyGroup) =
+          getCatRange(max_elements_per_tensor, batchCounter);
+    }
+
+    if (memory_format != c10::MemoryFormat::Contiguous) {
+      switch (dimension) {
+        case 0:
+          break;
+        case 1:
+          dimension = nDims - dimension;
+          break;
+        default:
+          dimension--;
+      }
+    }
+
+// Template Declarations for dim = 1, 2, 3, 4
+#define HANDLE_CASE(DIMS)                                                      \
+  if (isContig && isAligned && sizeof(scalar_t) > 2 &&                         \
+      sizeof(scalar_t) <= 8) {                                                 \
+    CatArrayBatchedCopy_alignedK_contig<                                       \
+        scalar_t,                                                              \
+        unsigned int,                                                          \
+        DIMS,                                                                  \
+        batch_size,                                                            \
+        stride_size,                                                           \
+        ALIGNED_VEC_LOAD_BYTES_16>                                             \
+        kfn(data,                                                              \
+            catMetaData,                                                       \
+            outputParam,                                                       \
+            dimension,                                                         \
+            outputParam.tensorStride[dimension]);                              \
+    auto& q = getCurrentSYCLQueue();                                           \
+    sycl_kernel_submit(catRange, applyGroup, q, kfn);                          \
+  } else if (isContig && isAligned && sizeof(scalar_t) == 2) {                 \
+    CatArrayBatchedCopy_alignedK_contig<                                       \
+        scalar_t,                                                              \
+        unsigned int,                                                          \
+        DIMS,                                                                  \
+        batch_size,                                                            \
+        stride_size,                                                           \
+        ALIGNED_VEC_LOAD_BYTES_8>                                              \
+        kfn(data,                                                              \
+            catMetaData,                                                       \
+            outputParam,                                                       \
+            dimension,                                                         \
+            outputParam.tensorStride[dimension]);                              \
+    auto& q = getCurrentSYCLQueue();                                           \
+    sycl_kernel_submit(catRange, applyGroup, q, kfn);                          \
+  } else if (isContig) {                                                       \
+    CatArrayBatchedCopy_contig<                                                \
+        scalar_t,                                                              \
+        unsigned int,                                                          \
+        DIMS,                                                                  \
+        batch_size,                                                            \
+        stride_size>                                                           \
+        kfn(data,                                                              \
+            catMetaData,                                                       \
+            outputParam,                                                       \
+            dimension,                                                         \
+            outputParam.tensorStride[dimension]);                              \
+    auto& q = getCurrentSYCLQueue();                                           \
+    sycl_kernel_submit(catRange, applyGroup, q, kfn);                          \
+  } else {                                                                     \
+    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size> \
+        kfn(data,                                                              \
+            catMetaData,                                                       \
+            outputParam,                                                       \
+            dimension,                                                         \
+            outputParam.tensorStride[dimension]);                              \
+    auto& q = getCurrentSYCLQueue();                                           \
+    sycl_kernel_submit(catRange, applyGroup, q, kfn);                          \
+  }
+
     switch (nDims) {
       case 1:
         HANDLE_CASE(1);
@@ -274,37 +524,22 @@ void parallel_cat(
       case 4:
         HANDLE_CASE(4);
         break;
-      case 5:
-
HANDLE_CASE(5); - break; - default: - break; } + #undef HANDLE_CASE } } -void check_shape_except_dim(Tensor& first, Tensor& second, int dimension) { - int first_dims = first.dim(); - int second_dims = second.dim(); - TORCH_CHECK( - first_dims == second_dims, "Tensors must have same number of dimensions"); - for (int dim = 0; dim < first_dims; dim++) { - if (dim == dimension) { - continue; - } - int64_t first_dim_size = first.size(dim); - int64_t second_dim_size = second.size(dim); - TORCH_CHECK( - first_dim_size == second_dim_size, - "Sizes of tensors must match except in dimension"); - } -} +// The kernels are templated on an opaque, self-aligned type of the correct +// size to avoid redundant kernels for different types of the same size. +template +struct alignas(N) OpaqueType { + char data[N]; +}; -// TODO: Evaluate latest PyTorch CUDA implementation for performance void cat_out_kernel( - const ITensorListRef& container, - int64_t dimension, + const ITensorListRef& tensors, + int64_t dim, int64_t valid, bool all_contiguous, bool all_same_dtype, @@ -315,94 +550,99 @@ void cat_out_kernel( return; } - MaterializedITensorListRef inputs = container.materialize(); - int numInputs = inputs.size(); - - int i, j; - int64_t offset; - bool hasSkippedInput = false; - Tensor notSkippedTensor; // non-owning reference - // empty tensor includes size[0], size[0, 0, ..., 0] (n-dim). - // here we only skip size[0], other empty sizes are not skipped. - auto should_skip = [](const Tensor& t) { - return t.numel() == 0 && t.dim() == 1; - }; - int nDims = 0; - - for (i = 0; i < numInputs; i++) { - if (should_skip(inputs[i])) { - hasSkippedInput = true; - continue; - } - nDims = inputs[i].get().dim(); - notSkippedTensor = inputs[i]; - } - - // If all inputs are empty tensors, return an empty tensor - if (!notSkippedTensor.defined()) { - return; - } - - TORCH_CHECK(numInputs > 0, "invalid number of inputs"); - TORCH_CHECK(dimension >= 0, "invalid dimension"); + auto materialized = tensors.materialize(); - Tensor first_tensor = inputs[0]; + // We parallelize the copy if all 6 conditions pass: + // + // 1. There is more than one input tensor + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. 
All input tensors can use 32-bit indexing - std::vector size(nDims); - - int64_t cat_dim_size = 0; - for (int i = 0; i < numInputs; i++) { - Tensor tensor = inputs[i]; - if (should_skip(tensor)) { - continue; - } - check_shape_except_dim(notSkippedTensor, tensor, dimension); - cat_dim_size += tensor.size(dimension); - } - - for (int dim = 0; dim < nDims; dim++) { - int64_t result_dim_size = notSkippedTensor.size(dim); - if (dim == dimension) { - result_dim_size = cat_dim_size; - } - size[dim] = result_dim_size; - } - - const bool all32BitIndexable = - std::all_of(inputs.begin(), inputs.end(), [](const Tensor& t) { + const bool all32BitIndexable = std::all_of( + materialized.begin(), materialized.end(), [](const Tensor& t) { return canUse32BitIndexMath(t); }); - const bool allContiguous = - std::all_of(inputs.begin(), inputs.end(), [](const Tensor& t) { - return !t.defined() || t.is_contiguous(); - }); - if (inputs.size() > 1 && !hasSkippedInput && - result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && - canUse32BitIndexMath(result) && allContiguous && all32BitIndexable && - all_same_dtype && - (inputs[0].get().scalar_type() == result.scalar_type())) { - AT_DISPATCH_V2( - result.scalar_type(), - "cat_xpu", - AT_WRAP([&]() { - parallel_cat( - result, inputs, dimension, nDims); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - kComplexHalf, - kHalf, - kBool, - kBFloat16, - AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + int nDims = materialized[valid].get().dim(); + + // We support the contiguous inputs and non-contiguous input (<=4 dims) in + // different ways For contiguous input, we don't need to pass stride meta data + // to kernel through constant memory. Therefore, we could pass more inputs to + // threads. For non-contiguous, we reduce the number of inputs passed to + // kernel due to the limitation of constant memory. 
+ + if (materialized.size() > 1 && result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + canUse32BitIndexMath(result) && all_contiguous && all32BitIndexable && + all_same_dtype) { + if (isBitsType(result.scalar_type())) { + AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_xpu", [&]() { + using dtype = OpaqueType; + parallel_cat( + result, materialized, dim, nDims, memory_format); + }); + } else { + AT_DISPATCH_V2( + result.scalar_type(), + "cat_xpu", + AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat( + result, materialized, dim, nDims, memory_format); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + kComplexHalf, + kHalf, + kBool, + kBFloat16, + AT_EXPAND(AT_FLOAT8_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + kFloat4_e2m1fn_x2); + } + } else if ( + materialized.size() > 1 && result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + canUse32BitIndexMath(result) && nDims <= CAT_ARRAY_MAX_INPUT_DIMS && + all32BitIndexable && all_same_dtype && + memory_format == c10::MemoryFormat::Contiguous) { + if (isBitsType(result.scalar_type())) { + AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_xpu", [&]() { + using dtype = OpaqueType; + parallel_cat( + result, materialized, dim, nDims, memory_format); + }); + } else { + AT_DISPATCH_V2( + result.scalar_type(), + "cat_xpu", + AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat< + dtype, + CAT_ARRAY_BATCH_SIZE / 2, + CAT_ARRAY_BATCH_SIZE / 2>( + result, materialized, dim, nDims, memory_format); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + kComplexHalf, + kHalf, + kBool, + kBFloat16, + kFloat8_e4m3fn, + kFloat8_e4m3fnuz, + kFloat8_e5m2, + kFloat8_e5m2fnuz, + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + kFloat4_e2m1fn_x2); + } } else { - offset = 0; - for (j = 0; j < numInputs; j++) { - if (should_skip(inputs[j])) + int64_t offset = 0; + for (const Tensor& t : materialized) { + if (cat_should_skip_tensor(t)) continue; - int64_t dimSize = inputs[j].get().size(dimension); - Tensor nt = at::narrow(result, dimension, offset, dimSize); - nt.copy_(inputs[j], false); + int64_t dimSize = t.size(dim); + Tensor nt = at::narrow(result, dim, offset, dimSize); + copy_(nt, t); offset += dimSize; } }
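// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): what CatArrIndexToOffset::
// compute does. Given the linear element index of one input, it produces the
// element offset inside the output, pretending the output's concat dimension
// had the input's dimSize; the caller then adds offset * dimStride to land in
// the right slot. Plain host C++, with a hypothetical helper name.

#include <cstdio>

template <int Dims>
unsigned int cat_index_to_offset(
    const unsigned int size[Dims],
    const unsigned int stride[Dims],
    unsigned int dimSize,
    int concatDim,
    unsigned int linearIndex) {
  unsigned int offset = 0;
  for (int i = Dims - 1; i >= 1; --i) {
    // Along the concat dimension, index with the *input's* extent.
    unsigned int curDimSize = (i == concatDim) ? dimSize : size[i];
    unsigned int nextDimIndex = linearIndex / curDimSize;
    unsigned int curDimIndex = linearIndex - curDimSize * nextDimIndex;
    offset += curDimIndex * stride[i];
    linearIndex = nextDimIndex;
  }
  return offset + linearIndex * stride[0];
}

int main() {
  // Output is 2x5, row-major (strides {5, 1}); the input being copied is 2x2,
  // i.e. dimSize = 2 along concatDim = 1. Input element 3 (row 1, col 1)
  // maps to output offset 1*5 + 1 = 6, before the offset*dimStride shift.
  unsigned int size[2] = {2, 5};
  unsigned int stride[2] = {5, 1};
  std::printf("%u\n", cat_index_to_offset<2>(size, stride, 2, 1, 3));
  return 0;
}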
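// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the work-group sizing used by
// getCatRange/getCatRangeContig, reduced to plain integer math. The EU count
// is an assumed number here; the real code queries syclGpuEUCountPerSubslice()
// and packs the result into sycl::range<2> objects. cat_groups is a
// hypothetical helper name.

#include <algorithm>
#include <cstdio>

static unsigned int ceil_div_u(unsigned int a, unsigned int b) {
  return (a + b - 1) / b;
}

// elements_per_item: how many scalars one work-item consumes, e.g.
// aligned_vec_load_bytes / sizeof(T) on the contiguous/vectorized path.
static unsigned int cat_groups(
    unsigned int max_elements_per_tensor,
    unsigned int elements_per_item,
    unsigned int num_eu /* assumed device query result */) {
  constexpr unsigned int items_per_group = 256;
  constexpr unsigned int max_group_per_eu = 32;
  unsigned int max_items =
      ceil_div_u(max_elements_per_tensor, elements_per_item);
  unsigned int item_groups = ceil_div_u(max_items, items_per_group);
  // Cap the group count so small tensors do not over-subscribe the device.
  return std::min(num_eu * max_group_per_eu, item_groups);
}

int main() {
  // float (4 B) with 16-byte vector loads -> 4 elements per work-item.
  unsigned int groups = cat_groups(1u << 20, 4, /*assumed EUs*/ 8);
  std::printf("work-groups per tensor: %u, global y-range: %u\n",
              groups, groups * 256);
  return 0;
}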
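// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): a host-side mock of the main
// loop in CatArrayBatchedCopy_alignedK_contig. kILP contiguous input elements
// are fetched with one wide, aligned load and then scattered to their
// (possibly strided) output offsets; a scalar tail handles the remainder.
// AlignedVec stands in for memory::aligned_vector, and the real kernel only
// takes this path after is_aligned_vec4() has verified the input pointers.

#include <cstddef>
#include <cstdio>
#include <vector>

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

template <typename T, int kILP>
void copy_vectorized(
    const T* in, T* out, const size_t* out_offsets, size_t n_elements) {
  size_t i = 0;
  for (; i + kILP <= n_elements; i += kILP) {
    // One wide load instead of kILP scalar loads (mirrors the kernel's cast
    // through the aligned vector type).
    AlignedVec<T, kILP> v =
        *reinterpret_cast<const AlignedVec<T, kILP>*>(in + i);
    for (int j = 0; j < kILP; ++j)
      out[out_offsets[i + j]] = v.val[j];
  }
  for (; i < n_elements; ++i) // tail when n_elements % kILP != 0
    out[out_offsets[i]] = in[i];
}

int main() {
  std::vector<float> in{0, 1, 2, 3, 4, 5, 6};
  std::vector<float> out(7, -1.f);
  std::vector<size_t> off{6, 5, 4, 3, 2, 1, 0}; // toy output offsets
  copy_vectorized<float, 4>(in.data(), out.data(), off.data(), in.size());
  std::printf("out[0]=%g out[6]=%g\n", out[0], out[6]); // prints 6 and 0
  return 0;
}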