
Commit 5b10b0a

drdarshan authored and pytorchmergebot committed
Slightly improve error message from repeat_interleave kernel (pytorch#157996)
Summary: In many investigations relating to invalid feature values, the three-argument form of `repeat_interleave` currently prints the following message if there is an inconsistency between `sum(repeats)` and `output_size`:

```
Assertion `result_size == cumsum_ptr[size - 1]` failed.
```

This is a bit hard for model authors to understand, so I made the error slightly more comprehensible. After the fix, stdout contains the actual values of these parameters: https://fburl.com/mlhub/cfyyhh3q

```
Invalid input! In `repeat_interleave`, the `output_size` argument (949487) must be the same as the sum of the elements in the `repeats` tensor (949687).
```

In many cases this is potentially useful information, since we know, for example, that the difference between the two values above (949687 - 949487 = 200) happens to be the length of one of the features.

## What are my concerns with this change?

1. Outputs from `__assert_fail` go to `stderr`, whereas `printf` writes to `stdout`. This is not the usual debugging flow, where all logs can be found in `stderr`. I could not find a way to redirect `printf` to stderr or `__assert_fail` to stdout.
2. Two checks happen instead of one in the error path. I wanted to preserve the semantics of what happens inside `__assert_fail`.
3. I have not seen this pattern in other PyTorch kernels, but `repeat_interleave` with three arguments seems special in other ways too.

Test Plan:
* Built an ephemeral package with my changes: https://www.internalfb.com/intern/servicelab/build/736441058/
* Verified that a job with these changes indeed prints the expected message to stdout: https://fburl.com/mlhub/jgbqk8eg
* I will export to GH and run CI/CD tests.

Rollback Plan:
steps:
  - manual.note:
      content: >-
        Just reverting this diff should be sufficient. Since this change is in
        CUDA kernels, I do not believe there is a way to change the error
        message via a JK.

Reviewed By: mradmila

Differential Revision: D77904753

Pull Request resolved: pytorch#157996
Approved by: https://github.com/ngimel, https://github.com/eqy
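
For context, a minimal Python sketch (not part of this change) of how the mismatch this kernel guards against can be reproduced from the public API. The tensor values and the deliberately wrong `output_size` are made up for illustration, and a CUDA device is assumed:

```python
import torch

x = torch.arange(4, device="cuda")
repeats = torch.tensor([1, 2, 3, 4], device="cuda")  # sum(repeats) == 10

# Passing an output_size that disagrees with sum(repeats) is the invalid
# input the kernel assertion reports; the device-side assert typically
# surfaces at the next synchronization point.
out = torch.repeat_interleave(x, repeats, dim=0, output_size=9)
torch.cuda.synchronize()
```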
1 parent fb462ce commit 5b10b0a

File tree

1 file changed: +7 -1 lines changed


aten/src/ATen/native/cuda/Repeat.cu

Lines changed: 7 additions & 1 deletion
@@ -17,7 +17,13 @@ __global__ static void compute_cuda_kernel(
     index_t* result_ptr,
     int64_t size,
     int64_t result_size) {
-  CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]);
+  if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) {
+    printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] "
+           "Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n",
+           __FILE__, __LINE__, __func__, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1]);
+    CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1])
+  }
+
   int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
   int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
   int warp_id = idx / C10_WARP_SIZE;
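
For model authors who would rather fail on the host than inside the kernel, here is a hedged sketch of a pre-check that mirrors the kernel's invariant. `checked_repeat_interleave` is a hypothetical helper, not part of this PR, and the explicit `repeats.sum()` forces the stream synchronization that `output_size` exists to avoid:

```python
import torch

def checked_repeat_interleave(x, repeats, dim=0, output_size=None):
    """Hypothetical wrapper: validate output_size on the host before the
    kernel launch, trading a device sync for an eager, readable error."""
    if output_size is not None:
        total = int(repeats.sum().item())  # synchronizes with the device
        if total != output_size:
            raise ValueError(
                f"`output_size` ({output_size}) must equal the sum of "
                f"`repeats` ({total})"
            )
    return torch.repeat_interleave(x, repeats, dim=dim, output_size=output_size)
```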
