Skip to content

Commit deb7763

Browse files
rraminen authored and pytorchmergebot committed
[ROCm] Reduce duplication in bfloat16_support_literal definition (pytorch#166147)
This PR refactors the bfloat16_support_literal constant in the PyTorch build logic to eliminate duplicated ROCm-specific code. Previously, there were two nearly identical branches for ROCM_VERSION < 70000 and ROCM_VERSION >= 70000, differing only by a single typedef. These have been unified into one conditional block with a minimal version guard inside. (#2502)
Pull Request resolved: pytorch#166147
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
1 parent d7040e6 commit deb7763

File tree

1 file changed

+7
-69
lines changed

1 file changed

+7
-69
lines changed

torch/csrc/jit/codegen/fuser/cuda/resource_strings.h

Lines changed: 7 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -260,82 +260,20 @@ typedef __half half;
260260
)";
261261
#endif
262262

263-
#if defined(USE_ROCM) && ROCM_VERSION < 70000
264-
constexpr auto bfloat16_support_literal =
265-
R"(
266-
#ifndef __align__
267-
#define __align__(x) __attribute__((aligned(x)))
268-
#endif
269-
270-
typedef struct __align__(2) {
271-
unsigned short x;
272-
}
273-
__nv_bfloat16_raw;
274-
275-
#if defined(__cplusplus)
276-
struct __align__(2) __nv_bfloat16 {
277-
__host__ __device__ __nv_bfloat16() {}
278-
279-
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
280-
__x = hr.x;
281-
return *this;
282-
}
283-
284-
unsigned short __x;
285-
};
286-
287-
__device__ unsigned short __internal_float2bfloat16(
288-
const float f,
289-
unsigned int& sign,
290-
unsigned int& remainder) {
291-
unsigned int x;
292-
293-
x = __float_as_uint(f);
294-
295-
if ((x & 0x7fffffffU) > 0x7f800000U) {
296-
sign = 0U;
297-
remainder = 0U;
298-
return static_cast<unsigned short>(0x7fffU);
299-
}
300-
sign = x >> 31;
301-
remainder = x << 16;
302-
return static_cast<unsigned short>(x >> 16);
303-
}
263+
#if defined(USE_ROCM)
304264

305-
/* Definitions of intrinsics */
306-
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
307-
__nv_bfloat16 val;
308-
__nv_bfloat16_raw r;
309-
unsigned int sign;
310-
unsigned int remainder;
311-
r.x = __internal_float2bfloat16(a, sign, remainder);
312-
if ((remainder > 0x80000000U) ||
313-
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
314-
r.x++;
315-
}
316-
val = r;
317-
return val;
318-
}
265+
#if ROCM_VERSION >= 70000
266+
#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n"
267+
#else
268+
#define BF16_UINT32_DEF ""
269+
#endif
319270

320-
__device__ float __bfloat162float(const __nv_bfloat16 a) {
321-
union
322-
{
323-
uint32_t int32;
324-
float fp32;
325-
} u = {uint32_t(a.__x) << 16};
326-
return u.fp32;
327-
}
328-
#endif /* defined(__cplusplus) */
329-
)";
330-
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
331271
constexpr auto bfloat16_support_literal =
332272
R"(
333273
#ifndef __align__
334274
#define __align__(x) __attribute__((aligned(x)))
335275
#endif
336-
337-
typedef unsigned int uint32_t;
338-
276+
)" BF16_UINT32_DEF R"(
339277
typedef struct __align__(2) {
340278
unsigned short x;
341279
}

0 commit comments

Comments
 (0)