Skip to content

Commit 33e94c1

Browse files
committed
Add BF16->FP32 emulation
1 parent 9284e55 commit 33e94c1

File tree

3 files changed

+66
-36
lines changed

3 files changed

+66
-36
lines changed

tensorforge/backend/instructions/compute/primitives/amd.py

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -392,34 +392,34 @@ def mfma_emu_int8(writer: Writer, C, B, A, c, a, b):
392392
c = writer.varalloc()
393393
writer(f'const auto {a} = static_cast<uint8x4_t>({Aa} % {x});')
394394
writer(f'const auto {b} = static_cast<uint8x4_t>({Ba} % {x});')
395-
writer(f'{c} = __builtin_amdgcn_mfma_i32_4x4x4i8(get_native_vector({a}), get_native_vector({b}), 0, {c}, {a}, {b});')
395+
writer(f'{c} = __builtin_amdgcn_mfma_i32_4x4x4i8({a}, {b}, 0, {c}, {a}, {b});')
396396
writer(f'{Ca} += {c} * {y};')
397397

398398
# TODO: scale back
399399

400400
def mfma_emu_bf16_f32(writer: Writer, C, B, A, c, a, b):
    """Emit code emulating an FP32 4x4x4 MFMA via the BF16 MFMA instruction.

    Each FP32 operand lane is split into three BF16 parts
    (p0 + p1 + p2 ~= value) with ``tensorforge::splitFloatx4BF16``; the
    product is then accumulated from six cross-term MFMAs.  The remaining
    terms (p1*p2, p2*p1, p2*p2) are omitted — presumably because they fall
    below FP32 precision (TODO confirm the intended accuracy target).

    Args:
        writer: code emitter; called once per generated C++ source line.
        C: name of the FP32 accumulator variable (updated in place).
        B, A: sequences of 4 FP32 scalar variable names (one x4 operand each).
        c, a, b: trailing modifier arguments (cbsz/abid/blgp-style) forwarded
            verbatim to the MFMA builtin.
    """
    writer(f'const auto [{A[0]}_p0, {A[0]}_p1, {A[0]}_p2] = tensorforge::splitFloatx4BF16({A[0]}, {A[1]}, {A[2]}, {A[3]});')
    writer(f'const auto [{B[0]}_p0, {B[0]}_p1, {B[0]}_p2] = tensorforge::splitFloatx4BF16({B[0]}, {B[1]}, {B[2]}, {B[3]});')
    # Cross terms in descending significance: (0,0), (0,1), (1,0), (0,2),
    # (2,0), (1,1).  Order matches the original hand-unrolled sequence.
    for ai, bi in ((0, 0), (0, 1), (1, 0), (0, 2), (2, 0), (1, 1)):
        writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({A[0]}_p{ai}, {B[0]}_p{bi}, {C}, {c}, {a}, {b});')
409+
410+
def mfma_emu_f16_f32(writer: Writer, C, B, A, c, a, b):
411+
Ar = writer.varalloc()
401412
A1 = writer.varalloc()
402413
A2 = writer.varalloc()
403-
A3 = writer.varalloc()
414+
Br = writer.varalloc()
404415
B1 = writer.varalloc()
405416
B2 = writer.varalloc()
406-
B3 = writer.varalloc()
407-
Ar = writer.varalloc()
408-
Br = writer.varalloc()
409-
writer(f'const bfloat16x4 {A1} = bfloat16x4({A});')
410-
writer(f'const bfloat16x4 {B1} = bfloat16x4({B});')
411-
writer(f'const bfloat16x4 {Ar} = {A} - {A1};')
412-
writer(f'const bfloat16x4 {Br} = {B} - {B1};')
413-
writer(f'const bfloat16x4 {A2} = bfloat16x4({Ar});')
414-
writer(f'const bfloat16x4 {B2} = bfloat16x4({Br});')
415-
writer(f'const bfloat16x4 {A3} = bfloat16x4({Ar} - {A2});')
416-
writer(f'const bfloat16x4 {B3} = bfloat16x4({Br} - {B2});')
417-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
418-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B2}), {C}, {c}, {a}, {b});')
419-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A2}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
420-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B3}), {C}, {c}, {a}, {b});')
421-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A3}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
422-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A2}), get_native_vector({B2}), {C}, {c}, {a}, {b});')
417+
writer(f'const f16x4 {Ar} = f16x4({A});')
418+
writer(f'const f16x4 {Br} = f16x4({B});')
419+
writer(f'const f16x4 {A1} = f16x4({A});')
420+
writer(f'const f16x4 {B1} = f16x4({B});')
421+
writer(f'const f16x4 {A2} = f16x4({A} - {A1});')
422+
writer(f'const f16x4 {B2} = f16x4({B} - {B1});')
423423

424424
def matmul32(writer: Writer, C, B, A, M, N, K, kx, threads):
425425
with writer.AnonymousScope():
@@ -449,11 +449,13 @@ def write_matmul(block, start, cap):
449449
}[threads]
450450
}[block]
451451
fn = {
452+
1: f'fmacdpp16<0>()',
452453
4: '__builtin_amdgcn_mfma_f32_4x4x1f32',
453454
16: '__builtin_amdgcn_mfma_f32_16x16x1f32',
454455
32: '__builtin_amdgcn_mfma_f32_32x32x1f32'
455456
}[block]
456457
tp = {
458+
1: lambda tmpA: '',
457459
4: lambda tmpA: f'tensorforge::transpose4x4b32({tmpA}_0, {tmpA}_1, {tmpA}_2, {tmpA}_3, {tmpA}_0, {tmpA}_1, {tmpA}_2, {tmpA}_3)',
458460
16: lambda tmpA: f'tensorforge::transpose16x16b32({", ".join(f"{tmpA}_{i}" for i in range(16))})',
459461
32: lambda tmpA: f'tensorforge::transpose32x32b32({", ".join(f"{tmpA}_{i}" for i in range(32))})'
@@ -474,21 +476,33 @@ def write_matmul(block, start, cap):
474476
for i in range(0, M):
475477
with writer.AnonymousScope():
476478
writer(f'tensorforge::VectorT<float, {block}> {tmpacc}{"{}"};')
477-
for k in range(0, K, threads):
478-
dk = min(threads, K - k)
479+
for k in range(0, K + kx, threads):
480+
dk = min(threads, K + kx - k)
479481
for kk in range(0, dk, block):
480482
with writer.AnonymousScope():
481483
fB = [False] * block
482-
for kkk in range(min(block, dk - kk)):
484+
dkk = min(block, dk - kk)
485+
for kkk in range(dkk):
483486
fB[kkk] = B(writer, f'{tmpB}_{kkk}', i, k + kk + kkk)
484-
for kkk in range(min(block, dk - kk)):
485-
if fB[kkk]:
486-
trueK = k + kk + kkk + kx
487-
km = trueK // threads
488-
kkm = ((trueK % threads) // block)
489-
kkkm = trueK % block
490-
# the index for tmpB is correct
491-
writer(f'{tmpacc} = {fn}({tmpA}_{km}_{kkkm}, {tmpB}_{kkk}, {tmpacc}, {scale}, {kkm}, 0);')
487+
for kkk in range(dkk, block):
488+
writer(f'float {tmpB}_{kkk} = 0;')
489+
if True:
490+
Ar = [f'{tmpA}_{k // threads}_{kkk}' for kkk in range(4)]
491+
Br = [f'{tmpB}_{kkk}' for kkk in range(4)]
492+
mfma_emu_bf16_f32(writer, tmpacc, Br, Ar, scale, kk // 4, 0)
493+
else:
494+
for kkk in range(dkk):
495+
if fB[kkk]:
496+
trueK = k + kk + kkk #+ kx
497+
km = trueK // threads
498+
kkm = ((trueK % threads) // block)
499+
kkkm = trueK % block
500+
501+
assert km == k
502+
assert kkm == kk
503+
assert kkkm == kkk
504+
# the index for tmpB is correct
505+
writer(f'{tmpacc} = {fn}({tmpA}_{km}_{kkkm}, {tmpB}_{kkk}, {tmpacc}, {scale}, {kkm}, 0);')
492506

493507
for jj in range(min(block, N - j)):
494508
C(writer, f'{tmpacc}[{jj}]', i, j + jj)
@@ -500,7 +514,9 @@ def write_matmul(block, start, cap):
500514
#if N >= 16 and threads >= 16:
501515
# write_matmul(16, start, True)
502516
# start += (N // 16) * 16
503-
write_matmul(4, start, False)
517+
cap4 = False #N % 4 < 2
518+
write_matmul(4, start, cap4)
519+
# write_matmul(1, )
504520

505521
def fmadpp16(writer, C, A, B, row):
506522
writer(f'tensorforge::fmacdpp16<{row}>({C}, {A}, {B});')

tensorforge/backend/opt/optimizer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ def optimize(self):
3333
opt.apply()
3434
self._instrs = opt.get_instructions()
3535

36-
opt = MultiBuffer(self._context, self._instrs, self._shr_mem, self._scopes)
37-
opt.apply()
38-
self._instrs = opt.get_instructions()
39-
self._global_instrs = opt._global_instrs
36+
if self._context.get_vm().get_hw_descr().vendor == 'amd':
37+
opt = MultiBuffer(self._context, self._instrs, self._shr_mem, self._scopes)
38+
opt.apply()
39+
self._instrs = opt.get_instructions()
40+
self._global_instrs = opt._global_instrs
4041

4142
opt = PtrPipe(self._context, self._instrs)
4243
opt.apply()

tensorforge/include/tensorforge_device/hip.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,8 @@ class Loader {
864864
};
865865
*/
866866

867-
std::tuple<short, short, short> splitFloatBF16(float input) {
867+
__device__ __forceinline__ std::tuple<short, short, short>
868+
splitFloatBF16(float input) {
868869
const auto i1 = static_cast<__bf16>(input);
869870
const auto i1r = input - static_cast<float>(i1);
870871
const auto i2 = static_cast<__bf16>(i1r);
@@ -876,4 +877,16 @@ std::tuple<short, short, short> splitFloatBF16(float input) {
876877
return {r1, r2, r3};
877878
}
878879

880+
// Split four FP32 values into three BF16x4 part-vectors such that, per lane,
// part0 + part1 + part2 approximates the original FP32 value.  Each scalar is
// decomposed by splitFloatBF16 (defined above); the per-value parts are then
// regrouped into per-part vectors (structure-of-arrays layout) so each part
// can be fed directly as one operand of a 4x4x4 BF16 MFMA.
__device__ __forceinline__
std::tuple<VectorT<short, 4>, VectorT<short, 4>, VectorT<short, 4>>
splitFloatx4BF16(float i1, float i2, float i3, float i4) {
  const auto [i1p0, i1p1, i1p2] = splitFloatBF16(i1);
  const auto [i2p0, i2p1, i2p2] = splitFloatBF16(i2);
  const auto [i3p0, i3p1, i3p2] = splitFloatBF16(i3);
  const auto [i4p0, i4p1, i4p2] = splitFloatBF16(i4);
  // Gather part k of every input into the k-th returned vector.
  return {VectorT<short, 4>{i1p0, i2p0, i3p0, i4p0},
          VectorT<short, 4>{i1p1, i2p1, i3p1, i4p1},
          VectorT<short, 4>{i1p2, i2p2, i3p2, i4p2}};
}
891+
879892
} // namespace tensorforge

0 commit comments

Comments
 (0)