@@ -226,38 +226,6 @@ __global__ void CatArrayBatchedCopy_contig(
   }
 }
 
-
-template <typename T, typename IndexType, int Dims, int batch_size, int stride_size, int alignment, int elems_per_vec>
-__global__ void CatArrayBatchedCopy_vectorized(
-    char* output,
-    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
-    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
-    const int concatDim,
-    IndexType trailingSize) {
-
-  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
-  IndexType nElements = inputs.nElements[blockIdx.y] / elems_per_vec;
-
-  if (tid >= nElements) return;
-
-  const char* data = (char*)inputs.input[blockIdx.y];
-  IndexType offset = inputs.offset[blockIdx.y] * trailingSize / elems_per_vec;
-  IndexType dimSize = inputs.dimSize[blockIdx.y] * trailingSize / elems_per_vec;
-  IndexType dataOffset = offset * alignment; // in bytes
-
-  IndexType stride = gridDim.x * blockDim.x;
-
-  while (tid < nElements) {
-    IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
-        os.tensorSize, os.tensorStride, dimSize, concatDim, tid) * alignment; // in bytes
-    auto vec = at::native::memory::ld_vec<alignment>(data + alignment * tid);
-    at::native::memory::st_vec<alignment>(output + dataOffset + elementOffset, vec);
-    tid += stride;
-  }
-}
-
-
-
 /*
   Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads
   to improve memory bandwidth throughput.
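For context on the path being deleted above: CatArrayBatchedCopy_vectorized moves whole `alignment`-byte packets per thread rather than individual scalars, which is what produces the wide loads. A minimal sketch of that access pattern, assuming a 16-byte packet can be modelled with CUDA's built-in uint4 (the real kernel goes through at::native::memory::ld_vec/st_vec and the CatArrIndexToOffset indexing; the kernel name and launch mapping below are illustrative only):

// Illustrative sketch only: grid-stride copy of n 16-byte packets.
// Assumes src and dst are both 16-byte aligned, as the removed path requires.
__global__ void copy_packets16_sketch(const char* src, char* dst, unsigned int n) {
  unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int stride = gridDim.x * blockDim.x;
  for (; tid < n; tid += stride) {
    // uint4 is 16 bytes, so a single load/store moves one full packet,
    // mirroring ld_vec<16>/st_vec<16> in the deleted kernel.
    uint4 v = *reinterpret_cast<const uint4*>(src + 16u * tid);
    *reinterpret_cast<uint4*>(dst + 16u * tid) = v;
  }
}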
@@ -328,27 +296,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
   CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
   TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;
-  // If all batches are contiguous we can call a specialized implementation
-  // which requires the input tensor addresses to be aligned to a
-  // 16 Byte boundary.
-
-  constexpr bool isContig = stride_size == 1;
-  bool isAligned = true;
-  constexpr int alignment = 16;
 
   // Next, let's initialize the size, stride arrays for the output Tensor.
-  // for contig case, we'll canonicalize output strides, so that
-  // we don't have arbitrary strides for dims of size 0
-  size_t stride0 = 1;
   if (memory_format == c10::MemoryFormat::Contiguous) {
-    for (int i = nDims - 1; i >= 0; --i) {
+    for (int i = 0; i < nDims; ++i) {
       outputParam.tensorSize[i] = out.size(i);
-      if (isContig) {
-        outputParam.tensorStride[i] = stride0;
-        stride0 *= out.size(i);
-      } else {
-        outputParam.tensorStride[i] = out.stride(i);
-      }
+      outputParam.tensorStride[i] = out.stride(i);
     }
   } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
     // permute the semantics of dims from NCHW to NHWC so that the input
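The deleted lines in this hunk canonicalized the contiguous output strides from the sizes instead of trusting out.stride(i), so dimensions of size 0 cannot contribute arbitrary stride values. A host-side sketch of that reverse scan, using plain std::vector instead of the tensor API (names are illustrative):

#include <cstdint>
#include <vector>

// Illustrative sketch: canonical contiguous strides from sizes.
// The innermost dimension gets stride 1; each outer stride is the product
// of all inner sizes, regardless of what the tensor itself reports.
std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride0 = 1;
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = stride0;
    stride0 *= sizes[i];
  }
  return strides;
}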
@@ -367,15 +320,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 
   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
 
+  // If all batches are contiguous we can call a specialized implementation
+  // which requires the input tensor addresses to be aligned to a
+  // 16 Byte boundary.
 
-  // for channels last computing slice size correctly is much more involved, so we never send it
-  // on the fully vectorized path
-  // we need output stride in cat dimension to be multiple of alignment,
-  // if we ever use it to compute offsets
-  // for catting in 0th dimension it doesn't matter
-  bool isInOutAligned = isContig && at::native::memory::get_alignment(data) >= alignment &&
-      memory_format == c10::MemoryFormat::Contiguous && (dimension == 0 ||
-      outputParam.tensorStride[dimension - 1] * sizeof(scalar_t) % alignment == 0);
+  bool isContig = true;
+  bool isAligned = true;
   unsigned int max_elements_per_tensor = 0;
 
   // Now we loop
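Both the removed isInOutAligned gate and the isAligned flag introduced here come down to the same test: is a data pointer aligned to the vector width the kernel wants to load? A minimal sketch of such a check, assuming a plain address-modulo test (the actual code relies on is_aligned_vec4 and at::native::memory::get_alignment):

#include <cstddef>
#include <cstdint>

// Illustrative sketch: true if ptr can be accessed with width-byte vector
// loads/stores, i.e. its address is a multiple of width.
inline bool is_aligned_to(const void* ptr, std::size_t width) {
  return reinterpret_cast<std::uintptr_t>(ptr) % width == 0;
}

// e.g. is_aligned_to(catMetaData.input[batchCounter], 16) corresponds to the
// 16 Byte boundary mentioned in the comment added above.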
@@ -391,16 +341,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       // high-dimensional tensor
       if (inputs[i+batchCounter].get().numel() > 0) {
         dimSize = inputs[i+batchCounter].get().size(dimension);
-        if (isInOutAligned) {
-          auto t = inputs[i+batchCounter].get();
-          // similarly to output stride, we cannot trust stride value to
-          // determine slice size if the corresponding dimension is 1
-          // we have to multiply all the subsequent sizes
-          int64_t slice_size = dimension == 0 ? t.numel() : t.sizes()[dimension - 1] != 1 ?
-              t.strides()[dimension - 1] : c10::multiply_integers(t.sizes().begin() + dimension, t.sizes().end());
-          slice_size *= sizeof(scalar_t);
-          isInOutAligned &= (slice_size % alignment == 0);
-        }
       }
 
       catMetaData.input[batchCounter] = (scalar_t *)(inputs[i+batchCounter].get().const_data_ptr());
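The slice-size rule being removed here reads: for a cat along dim 0 the whole tensor is one slice; otherwise the stride of the preceding dimension spans one slice, unless that dimension has size 1 (a size-1 dim may carry an arbitrary stride), in which case the trailing sizes are multiplied out, and the resulting byte count must be a multiple of the 16-byte alignment. A standalone sketch of that rule with plain vectors instead of at::Tensor (function and parameter names are made up for illustration):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative sketch of the deleted slice-size computation, in elements.
int64_t cat_slice_elems(const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& strides,
                        int64_t dim) {
  if (dim == 0) {
    // Concatenating along dim 0: one slice is the whole tensor.
    return std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
  if (sizes[dim - 1] != 1) {
    // The stride of the dimension just before the cat dim spans one slice.
    return strides[dim - 1];
  }
  // A size-1 dim can have an arbitrary stride, so multiply the trailing
  // sizes instead (this stands in for c10::multiply_integers).
  return std::accumulate(sizes.begin() + dim, sizes.end(), int64_t{1},
                         std::multiplies<int64_t>());
}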
@@ -411,12 +351,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 #ifdef USE_ROCM
       // On ROCm, CatArrayBatchedCopy_contig is faster
       isAligned = false;
-      isInOutAligned = false;
 #else
       // If at least one of the inputs is not aligned, we can't call the
       // CatArrayBatchedCopy_alignedK_contig
       isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
-      isInOutAligned &= at::native::memory::get_alignment(catMetaData.input[batchCounter]) >= alignment;
 #endif
 
       if (stride_size > 1) {
@@ -427,6 +365,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
         }
         catMetaData.isContiguous[batchCounter] = false;
+        isContig = false;
       } else {
         catMetaData.isContiguous[batchCounter] = true;
       }
@@ -449,44 +388,17 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
         max_elements_per_tensor, batchCounter);
 #else
     dim3 applyBlock, catGrid;
-    if (isInOutAligned) {
-      std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, alignment>(
-          max_elements_per_tensor, batchCounter);
-    } else if (isContig && isAligned && sizeof(scalar_t) > 2) {
+    if (isContig && sizeof(scalar_t) > 2) {
       std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_16>(
           max_elements_per_tensor, batchCounter);
-    } else if (isContig && isAligned && sizeof(scalar_t) == 2) {
+    } else if (isContig && sizeof(scalar_t) == 2) {
       std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_8>(
           max_elements_per_tensor, batchCounter);
     } else {
       applyBlock = dim3(32 * 16);
       getCatGrid(batchCounter, catGrid);
     }
 #endif
-    int32_t trailingSize;
-    TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
-    if (isInOutAligned) {
-      // in this case we can and should flatten the tensors after the cat dim
-      // we want to view the tensors as if consisting of `alignment`-sized elements
-      // however, we might not be able to cleanly divide just the last dim -
-      // it might not be the multiple of alignment.
-      // however, we know that the full concatted slice is multiple of alignment,
-      // so if we flatten all the dims after and including concat dim,
-      // it will be divisible by alignment
-      // then we need to divide last out size by elems_per_vec,
-      // and divide all strides except last by elems_per_vec (last stride is 1 always)
-      // for input, we will fix up the sizes and strides in the kernel directly
-      kernelOutputParam = outputParam;
-      nDims = dimension + 1;
-      constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
-      auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
-      kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
-      trailingSize = outputParam.tensorStride[dimension];
-      kernelOutputParam.tensorStride[dimension] = 1;
-      for (int i = 0; i < dimension; ++i) {
-        kernelOutputParam.tensorStride[i] /= elems_per_vec;
-      }
-    }
 
     if (memory_format != c10::MemoryFormat::Contiguous) {
       switch (dimension) {
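As a concrete reading of the flattening logic deleted above (the shapes are invented for illustration): concatenating two contiguous float tensors of shape [8, 6, 10] along dimension 1 yields an output of shape [8, 12, 10] with strides [120, 10, 1]. With a 16-byte alignment and 4-byte floats, elems_per_vec is 4. The last dimension alone (10 floats = 40 bytes) does not divide evenly into 16-byte packets, but the flattened slice from the cat dimension onward does (120 floats = 480 bytes per output row, 240 bytes per input row), so the deleted code viewed each output row as 120 / 4 = 30 vector elements, set trailingSize to the cat-dimension stride of 10, and divided the remaining output stride of 120 by 4 to give 30.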
@@ -501,12 +413,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   }
   // Template Declarations for dim = 1, 2, 3, 4
 #define HANDLE_CASE(DIMS) \
-  if (isInOutAligned) {\
-    constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
-    CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
-        catGrid, applyBlock, 0, stream.stream()>>>(\
-        (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
-  } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
+  if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
     CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
         catGrid, applyBlock, 0, stream.stream()>>>(\
         data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\