Skip to content

Commit 846f6b4

Browse files
rraminen authored and pragupta committed
[release/2.8] Define uint32_t when ROCM_VERSION >= 70000 (#2513)
Fixes SWDEV-543698 (https://ontrack-internal.amd.com/browse/SWDEV-543698).

Cherry-picked from #2502.

This PR fixes errors like the one below:

```
[rank3]: RuntimeError: The following operation failed in the TorchScript interpreter.
[rank3]: Traceback of TorchScript (most recent call last):
[rank3]: RuntimeError: /tmp/comgr-28f951/input/CompileSourceACC062:67:7: error: unknown type name 'uint32_t'; did you mean '__hip_internal::uint32_t'?
[rank3]:    67 |       uint32_t int32;
[rank3]:       |       ^~~~~~~~
[rank3]:       |       __hip_internal::uint32_t
```

Previously, the HIP headers defined uint32_t in the std namespace. As of ROCm 7.0, the HIP headers define it in the __hip_internal namespace instead, so the fuser's bfloat16 support source now declares uint32_t itself when ROCM_VERSION >= 70000.

(cherry picked from commit b2fb688)
1 parent dec58b7 commit 846f6b4

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

torch/csrc/jit/codegen/fuser/cuda/resource_strings.h

Lines changed: 70 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,7 @@ typedef __half half;
260260
)";
261261
#endif
262262

263-
#if defined(USE_ROCM)
263+
#if defined(USE_ROCM) && ROCM_VERSION < 70000
264264
constexpr auto bfloat16_support_literal =
265265
R"(
266266
#ifndef __align__
@@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
317317
return val;
318318
}
319319
320+
__device__ float __bfloat162float(const __nv_bfloat16 a) {
321+
union
322+
{
323+
uint32_t int32;
324+
float fp32;
325+
} u = {uint32_t(a.__x) << 16};
326+
return u.fp32;
327+
}
328+
#endif /* defined(__cplusplus) */
329+
)";
330+
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
331+
constexpr auto bfloat16_support_literal =
332+
R"(
333+
#ifndef __align__
334+
#define __align__(x) __attribute__((aligned(x)))
335+
#endif
336+
337+
typedef unsigned int uint32_t;
338+
339+
typedef struct __align__(2) {
340+
unsigned short x;
341+
}
342+
__nv_bfloat16_raw;
343+
344+
#if defined(__cplusplus)
345+
struct __align__(2) __nv_bfloat16 {
346+
__host__ __device__ __nv_bfloat16() {}
347+
348+
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
349+
__x = hr.x;
350+
return *this;
351+
}
352+
353+
unsigned short __x;
354+
};
355+
356+
__device__ unsigned short __internal_float2bfloat16(
357+
const float f,
358+
unsigned int& sign,
359+
unsigned int& remainder) {
360+
unsigned int x;
361+
362+
x = __float_as_uint(f);
363+
364+
if ((x & 0x7fffffffU) > 0x7f800000U) {
365+
sign = 0U;
366+
remainder = 0U;
367+
return static_cast<unsigned short>(0x7fffU);
368+
}
369+
sign = x >> 31;
370+
remainder = x << 16;
371+
return static_cast<unsigned short>(x >> 16);
372+
}
373+
374+
/* Definitions of intrinsics */
375+
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
376+
__nv_bfloat16 val;
377+
__nv_bfloat16_raw r;
378+
unsigned int sign;
379+
unsigned int remainder;
380+
r.x = __internal_float2bfloat16(a, sign, remainder);
381+
if ((remainder > 0x80000000U) ||
382+
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
383+
r.x++;
384+
}
385+
val = r;
386+
return val;
387+
}
388+
320389
__device__ float __bfloat162float(const __nv_bfloat16 a) {
321390
union
322391
{

0 commit comments

Comments
 (0)