@@ -359,16 +359,15 @@ static inline int C10_WARP_SIZE_INTERNAL() {
359359// Those platforms do not support assert()
360360#define CUDA_KERNEL_ASSERT (cond )
361361#define CUDA_KERNEL_ASSERT_MSG (cond, msg )
362+ #define CUDA_KERNEL_ASSERT_PRINTF (cond, msg, ...)
362363#define SYCL_KERNEL_ASSERT (cond )
363364#elif defined(_MSC_VER)
364365#if defined(NDEBUG)
365366extern " C" {
366367C10_IMPORT
367368#if defined(__SYCL_DEVICE_ONLY__)
368- extern SYCL_EXTERNAL void _wassert (
369- const wchar_t * wexpr,
370- const wchar_t * wfile,
371- unsigned line);
369+ extern SYCL_EXTERNAL void
370+ _wassert (const wchar_t * wexpr, const wchar_t * wfile, unsigned line);
372371#else
373372#if defined(__CUDA_ARCH__)
374373__host__ __device__
@@ -396,6 +395,26 @@ __host__ __device__
396395 static_cast <unsigned >(__LINE__)), \
397396 0 ); \
398397 }
398+ #define CUDA_KERNEL_ASSERT_PRINTF (cond, msg, ...) \
399+ if (C10_UNLIKELY(!(cond))) { \
400+ (void )(printf ( \
401+ " [CUDA_KERNEL_ASSERT] " __FILE__ " :" C10_STRINGIZE ( \
402+ __LINE__) " : %s: block: [%d,%d,%d], thread: [%d,%d,%d]: " \
403+ " Assertion failed: `" #cond " `: " msg " \n " , \
404+ __func__, \
405+ blockIdx.x , \
406+ blockIdx.y , \
407+ blockIdx.z , \
408+ threadIdx.x , \
409+ threadIdx.y , \
410+ threadIdx.z , \
411+ ##__VA_ARGS__)); \
412+ (void )(_wassert ( \
413+ _CRT_WIDE (#cond), \
414+ _CRT_WIDE (__FILE__), \
415+ static_cast <unsigned >(__LINE__)), \
416+ 0 ); \
417+ }
399418#define SYCL_KERNEL_ASSERT (cond ) \
400419 if (C10_UNLIKELY(!(cond))) { \
401420 (void )(_wassert ( \
@@ -415,11 +434,8 @@ extern SYCL_EXTERNAL void __assert_fail(
415434 const char * func);
416435#elif (defined(__EMSCRIPTEN__))
417436// As defined in assert.h in the Emscripten stdlib
418- _Noreturn void __assert_fail (
419- const char * expr,
420- const char * file,
421- int line,
422- const char * func);
437+ _Noreturn void
438+ __assert_fail (const char * expr, const char * file, int line, const char * func);
423439#else // __SYCL_DEVICE_ONLY__
424440#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__)))
425441// CUDA supports __assert_fail function which are common for both device
@@ -455,6 +471,10 @@ __host__ __device__
455471 if C10_UNLIKELY (!(cond)) { \
456472 abort (); \
457473 }
474+ #define CUDA_KERNEL_ASSERT_PRINTF (cond, msg, ...) \
475+ if C10_UNLIKELY (!(cond)) { \
476+ abort (); \
477+ }
458478#define SYCL_KERNEL_ASSERT (cond ) \
459479 if C10_UNLIKELY (!(cond)) { \
460480 abort (); \
@@ -470,6 +490,23 @@ __host__ __device__
470490 __assert_fail ( \
471491 msg, __FILE__, static_cast <unsigned int >(__LINE__), __func__); \
472492 }
493+ #define CUDA_KERNEL_ASSERT_PRINTF (cond, msg, ...) \
494+ if (C10_UNLIKELY(!(cond))) { \
495+ printf ( \
496+ " [CUDA_KERNEL_ASSERT] " __FILE__ " :" C10_STRINGIZE ( \
497+ __LINE__) " : %s: block: [%d,%d,%d], thread: [%d,%d,%d]: " \
498+ " Assertion failed: `" #cond " `: " msg " \n " , \
499+ __func__, \
500+ blockIdx.x , \
501+ blockIdx.y , \
502+ blockIdx.z , \
503+ threadIdx.x , \
504+ threadIdx.y , \
505+ threadIdx.z , \
506+ ##__VA_ARGS__); \
507+ __assert_fail ( \
508+ #cond, __FILE__, static_cast <unsigned int >(__LINE__), __func__); \
509+ }
473510#define SYCL_KERNEL_ASSERT (cond ) \
474511 if (C10_UNLIKELY(!(cond))) { \
475512 __assert_fail ( \
0 commit comments