Skip to content

Commit a92773e

Browse files
Revert "Use vectorized stores for all dtypes in cat (pytorch#161649)"
This reverts commit 3770337. Reverted pytorch#161649 on behalf of https://github.com/ngimel because it was reverted internally (see the comment on pytorch#161649).
1 parent 53297f6 commit a92773e

File tree

2 files changed

+11
-139
lines changed

2 files changed

+11
-139
lines changed

aten/src/ATen/native/cuda/Shape.cu

Lines changed: 11 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -226,38 +226,6 @@ __global__ void CatArrayBatchedCopy_contig(
226226
}
227227
}
228228

229-
230-
template <typename T, typename IndexType, int Dims, int batch_size, int stride_size, int alignment, int elems_per_vec>
231-
__global__ void CatArrayBatchedCopy_vectorized(
232-
char* output,
233-
CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
234-
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
235-
const int concatDim,
236-
IndexType trailingSize) {
237-
238-
IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
239-
IndexType nElements = inputs.nElements[blockIdx.y] / elems_per_vec;
240-
241-
if(tid >= nElements) return;
242-
243-
const char * data = (char*)inputs.input[blockIdx.y];
244-
IndexType offset = inputs.offset[blockIdx.y] * trailingSize / elems_per_vec;
245-
IndexType dimSize = inputs.dimSize[blockIdx.y] * trailingSize / elems_per_vec;
246-
IndexType dataOffset = offset * alignment; // in bytes
247-
248-
IndexType stride = gridDim.x * blockDim.x;
249-
250-
while( tid < nElements){
251-
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
252-
os.tensorSize, os.tensorStride, dimSize, concatDim, tid) * alignment; // in bytes
253-
auto vec = at::native::memory::ld_vec<alignment>(data + alignment * tid);
254-
at::native::memory::st_vec<alignment>(output + dataOffset + elementOffset, vec);
255-
tid += stride;
256-
}
257-
}
258-
259-
260-
261229
/*
262230
Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads
263231
to improve memory bandwidth throughput.
@@ -328,27 +296,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
328296
scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
329297
CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
330298
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;
331-
// If all batches are contiguous we can call a specialized implementation
332-
// which requires the input tensor addresses to be aligned to a
333-
// 16 Byte boundary.
334-
335-
constexpr bool isContig = stride_size == 1;
336-
bool isAligned = true;
337-
constexpr int alignment = 16;
338299

339300
// Next, let's initialize the size, stride arrays for the output Tensor.
340-
// for contig case, we'll canonicalize output strides, so that
341-
// we don't have arbitrary strides for dims of size 0
342-
size_t stride0 = 1;
343301
if (memory_format == c10::MemoryFormat::Contiguous) {
344-
for (int i = nDims - 1; i >= 0; --i) {
302+
for (int i = 0; i < nDims; ++i) {
345303
outputParam.tensorSize[i] = out.size(i);
346-
if (isContig) {
347-
outputParam.tensorStride[i] = stride0;
348-
stride0 *= out.size(i);
349-
} else {
350-
outputParam.tensorStride[i] = out.stride(i);
351-
}
304+
outputParam.tensorStride[i] = out.stride(i);
352305
}
353306
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
354307
// permute the semantics of dims from NCHW to NHWC so that the input
@@ -367,15 +320,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
367320

368321
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
369322

323+
// If all batches are contiguous we can call a specialized implementation
324+
// which requires the input tensor addresses to be aligned to a
325+
// 16 Byte boundary.
370326

371-
// for channels last computing slice size correctly is much more involved, so we never send it
372-
// on the fully vectorized path
373-
// we need output stride in cat dimension to be multiple of alignment,
374-
// if we ever use it to compute offsets
375-
// for catting in 0th dimension it doesn't matter
376-
bool isInOutAligned = isContig && at::native::memory::get_alignment(data) >= alignment &&
377-
memory_format == c10::MemoryFormat::Contiguous && (dimension == 0 ||
378-
outputParam.tensorStride[dimension - 1] * sizeof(scalar_t) % alignment == 0);
327+
bool isContig = true;
328+
bool isAligned = true;
379329
unsigned int max_elements_per_tensor = 0;
380330

381331
// Now we loop
@@ -391,16 +341,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
391341
// high-dimensional tensor
392342
if (inputs[i+batchCounter].get().numel() > 0) {
393343
dimSize = inputs[i+batchCounter].get().size(dimension);
394-
if (isInOutAligned) {
395-
auto t = inputs[i+batchCounter].get();
396-
// similarly to output stride, we cannot trust stride value to
397-
// determine slice size if the corresponding dimension is 1
398-
// we have to multiply all the subsequent sizes
399-
int64_t slice_size = dimension == 0 ? t.numel() : t.sizes()[dimension - 1] != 1 ?
400-
t.strides()[dimension - 1] : c10::multiply_integers(t.sizes().begin() + dimension, t.sizes().end());
401-
slice_size *= sizeof(scalar_t);
402-
isInOutAligned &= (slice_size % alignment == 0);
403-
}
404344
}
405345

406346
catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
@@ -411,12 +351,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
411351
#ifdef USE_ROCM
412352
// On ROCm, CatArrayBatchedCopy_contig is faster
413353
isAligned = false;
414-
isInOutAligned = false;
415354
#else
416355
// If at least one of the inputs is not aligned, we can't call the
417356
// CatArrayBatchedCopy_alignedK_contig
418357
isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
419-
isInOutAligned &= at::native::memory::get_alignment(catMetaData.input[batchCounter]) >= alignment;
420358
#endif
421359

422360
if (stride_size > 1) {
@@ -427,6 +365,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
427365
catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
428366
}
429367
catMetaData.isContiguous[batchCounter] = false;
368+
isContig = false;
430369
} else {
431370
catMetaData.isContiguous[batchCounter] = true;
432371
}
@@ -449,44 +388,17 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
449388
max_elements_per_tensor, batchCounter);
450389
#else
451390
dim3 applyBlock, catGrid;
452-
if (isInOutAligned) {
453-
std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, alignment>(
454-
max_elements_per_tensor, batchCounter);
455-
} else if (isContig && isAligned && sizeof(scalar_t) > 2) {
391+
if (isContig && sizeof(scalar_t) > 2) {
456392
std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_16>(
457393
max_elements_per_tensor, batchCounter);
458-
} else if (isContig && isAligned && sizeof(scalar_t) == 2) {
394+
} else if (isContig && sizeof(scalar_t) == 2) {
459395
std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_8>(
460396
max_elements_per_tensor, batchCounter);
461397
} else {
462398
applyBlock = dim3(32 * 16);
463399
getCatGrid(batchCounter, catGrid);
464400
}
465401
#endif
466-
int32_t trailingSize;
467-
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
468-
if (isInOutAligned) {
469-
// in this case we can and should flatten the tensors after the cat dim
470-
// we want to view the tensors as if consisting of `alignment`-sized elements
471-
// however, we might not be able to cleanly divide just the last dim -
472-
// it might not be the multiple of alignment.
473-
// however, we know that the full concatted slice is multiple of alignment,
474-
// so if we flatten all the dims after and including concat dim,
475-
// it will be divisible by alignment
476-
// then we need to divide last out size by elems_per_vec,
477-
// and divide all strides except last by elems_per_vec (last stride is 1 always)
478-
// for input, we will fix up the sizes and strides in the kernel directly
479-
kernelOutputParam = outputParam;
480-
nDims = dimension + 1;
481-
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
482-
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
483-
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
484-
trailingSize = outputParam.tensorStride[dimension];
485-
kernelOutputParam.tensorStride[dimension] = 1;
486-
for (int i = 0; i < dimension; ++i) {
487-
kernelOutputParam.tensorStride[i] /= elems_per_vec;
488-
}
489-
}
490402

491403
if (memory_format != c10::MemoryFormat::Contiguous) {
492404
switch (dimension) {
@@ -501,12 +413,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
501413
}
502414
// Template Declarations for dim = 1, 2, 3, 4
503415
#define HANDLE_CASE(DIMS) \
504-
if (isInOutAligned) {\
505-
constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
506-
CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
507-
catGrid, applyBlock, 0, stream.stream()>>>(\
508-
(char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
509-
} else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
416+
if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
510417
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
511418
catGrid, applyBlock, 0, stream.stream()>>>(\
512419
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\

test/test_tensor_creation_ops.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,41 +1151,6 @@ def test_cat2(self, device, dtype):
11511151
z = torch.cat([x, y])
11521152
self.assertEqual(z.size(), (21, SIZE, SIZE))
11531153

1154-
@dtypes(torch.float)
1155-
def test_cat_size1(self, device, dtype):
1156-
# create a tensor that has aligned stride along dim - 1 dimension
1157-
# but catted slice size is not aligned
1158-
x1 = torch.randn(16, 16, device=device, dtype=dtype)[:1, :1]
1159-
xref = x1.clone().view(-1).view(x1.shape)
1160-
# make sure output size is aligned, need at least 4 elements for this
1161-
res = torch.cat([x1, x1, x1, x1], dim=-1)
1162-
ref = torch.cat([xref, xref, xref, xref], dim=-1)
1163-
self.assertEqual(res, ref)
1164-
1165-
@dtypes(torch.float)
1166-
def test_cat_trailing_dim(self, device, dtype):
1167-
x1 = torch.randn(16, 16, 23, device=device, dtype=dtype)
1168-
x2 = torch.rand_like(x1)
1169-
res = torch.cat([x1, x2], dim=1)
1170-
ref = torch.cat([x1.cpu(), x2.cpu()], dim=1)
1171-
self.assertEqual(res, ref)
1172-
1173-
@dtypes(torch.float)
1174-
def test_cat_misaligned(self, device, dtype):
1175-
x1 = torch.randn(14, device=device, dtype=dtype)[2:]
1176-
x2 = torch.rand_like(x1)
1177-
res = torch.cat([x1, x2], dim=-1)
1178-
ref = torch.cat([x1.cpu(), x2.cpu()], dim=-1)
1179-
self.assertEqual(res, ref)
1180-
1181-
@dtypes(torch.float)
1182-
def test_cat_multi_batch(self, device, dtype):
1183-
xs = [torch.randn(16, 16, device=device, dtype=dtype) for _ in range(130)]
1184-
xs_cpu = [x.cpu() for x in xs]
1185-
res = torch.cat(xs, dim=-1)
1186-
ref = torch.cat(xs_cpu, dim=-1)
1187-
self.assertEqual(res, ref)
1188-
11891154
# FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
11901155
# creation methods and verify all dtypes and layouts
11911156
@dtypes(torch.bool, torch.uint8, torch.int16, torch.int64, torch.float16, torch.float32, torch.complex64)

0 commit comments

Comments
 (0)