Skip to content

Commit 524146e

Browse files
committed
Try FP16 emulation
1 parent 33e94c1 commit 524146e

File tree

2 files changed

+29
-13
lines changed
  • tensorforge
    • backend/instructions/compute/primitives
    • include/tensorforge_device

2 files changed

+29
-13
lines changed

tensorforge/backend/instructions/compute/primitives/amd.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -408,18 +408,11 @@ def mfma_emu_bf16_f32(writer: Writer, C, B, A, c, a, b):
408408
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({A[0]}_p1, {B[0]}_p1, {C}, {c}, {a}, {b});')
409409

410410
def mfma_emu_f16_f32(writer: Writer, C, B, A, c, a, b):
    # Emit device code that emulates an FP32 4x4x4 MFMA using three FP16
    # MFMAs.  Each float operand vector is decomposed by
    # tensorforge::splitFloatx4F16 into an f16 head (`_p0`) plus an f16
    # residual (`_p1`), and the product is accumulated as
    # A0*B0 + A1*B0 + A0*B1.
    # NOTE(review): the A1*B1 residual-by-residual term is deliberately
    # omitted — presumably it is below the FP32 accumulator's precision;
    # confirm against the accuracy target.
    for regs in (A, B):
        args = f'{regs[0]}, {regs[1]}, {regs[2]}, {regs[3]}'
        writer(f'const auto [{regs[0]}_p0, {regs[0]}_p1] = tensorforge::splitFloatx4F16({args});')
    products = (
        (f'{A[0]}_p0', f'{B[0]}_p0'),  # head * head
        (f'{A[0]}_p1', f'{B[0]}_p0'),  # residual * head
        (f'{A[0]}_p0', f'{B[0]}_p1'),  # head * residual
    )
    for lhs, rhs in products:
        writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4f16({lhs}, {rhs}, {C}, {c}, {a}, {b});')
423416

424417
def matmul32(writer: Writer, C, B, A, M, N, K, kx, threads):
425418
with writer.AnonymousScope():
@@ -486,10 +479,14 @@ def write_matmul(block, start, cap):
486479
fB[kkk] = B(writer, f'{tmpB}_{kkk}', i, k + kk + kkk)
487480
for kkk in range(dkk, block):
488481
writer(f'float {tmpB}_{kkk} = 0;')
489-
if True:
482+
if False:
490483
Ar = [f'{tmpA}_{k // threads}_{kkk}' for kkk in range(4)]
491484
Br = [f'{tmpB}_{kkk}' for kkk in range(4)]
492485
mfma_emu_bf16_f32(writer, tmpacc, Br, Ar, scale, kk // 4, 0)
486+
elif True:
487+
Ar = [f'{tmpA}_{k // threads}_{kkk}' for kkk in range(4)]
488+
Br = [f'{tmpB}_{kkk}' for kkk in range(4)]
489+
mfma_emu_f16_f32(writer, tmpacc, Br, Ar, scale, kk // 4, 0)
493490
else:
494491
for kkk in range(dkk):
495492
if fB[kkk]:

tensorforge/include/tensorforge_device/hip.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,4 +889,23 @@ __device__ __forceinline__
889889
VectorT<short, 4>{i1p2, i2p2, i3p2, i4p2}};
890890
}
891891

892+
// Split a float into an _Float16 head plus an _Float16 residual so that
// (float)head + (float)residual approximates the input more closely than a
// single _Float16 can.  Building block for the FP16-emulated FP32 MFMA path.
__device__ __forceinline__ std::tuple<_Float16, _Float16>
splitFloatF16(float input) {
  const _Float16 head = static_cast<_Float16>(input);
  const float remainder = input - static_cast<float>(head);
  const _Float16 residual = static_cast<_Float16>(remainder);
  return {head, residual};
}
899+
900+
__device__
901+
__forceinline__ std::tuple<VectorT<_Float16, 4>, VectorT<_Float16, 4>>
902+
splitFloatx4F16(float i1, float i2, float i3, float i4) {
903+
const auto [i1p0, i1p1] = splitFloatF16(i1);
904+
const auto [i2p0, i2p1] = splitFloatF16(i2);
905+
const auto [i3p0, i3p1] = splitFloatF16(i3);
906+
const auto [i4p0, i4p1] = splitFloatF16(i4);
907+
return {VectorT<_Float16, 4>{i1p0, i2p0, i3p0, i4p0},
908+
VectorT<_Float16, 4>{i1p1, i2p1, i3p1, i4p1}};
909+
}
910+
892911
} // namespace tensorforge

0 commit comments

Comments
 (0)