**Update:** I have found that for better performance we need to use 3 to 6 BF16
dot products, but not more. My findings are at:
https://gist.github.com/plotfi/72554bd410ea55d8ae67b501c69b2766
The short version is that the Triton Bench tutorial matmul with F32 inputs
benefits by 60-70% when using 3 BF16 dots, or by 10-15% when using 6 BF16 dots.
I think this is sufficient to move forward with as a TF32 replacement on the
MI350, and it is in line with what hipBLASLt does:
https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330
There is a similar implementation in XLA as well:
https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152
--------
Implements emulation of a 32-bit floating point dot operation using 3
BF16 values. This is based on https://arxiv.org/abs/1904.06376 and works
because the mantissas of three BF16s (8 significand bits each) add up to
the 24-bit mantissa of an fp32.
Storing 1 fp32 in 3 bf16s:
```python
import torch

def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    # b0 holds the high-order bits, b1 the middle bits, b2 the low-order bits.
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```
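As a quick sanity check (illustrative only, not part of the patch), the three
bf16 parts should sum back to roughly the original fp32 value, since their
significand bits together cover fp32's 24-bit significand:
```python
# Reconstruct the value from its three bf16 parts; the residual error
# should be near zero (at most a few ulps of fp32 in edge cases).
b0, b1, b2 = bf16x3
reconstructed = FP32(b0) + FP32(b1) + FP32(b2)
print((original - reconstructed).abs().max().item())
```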
Emulating multiplication of two fp32s:
```python
def mul_bf16x3(a, b, c):
    # Split both fp32 operands into (high, mid, low) bf16 parts and
    # accumulate all nine partial products into the fp32 accumulator c.
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # hi  * hi
    c = c + (a1 * b0)  # mid * hi
    c = c + (a0 * b1)  # hi  * mid
    c = c + (a1 * b1)  # mid * mid
    c = c + (a0 * b2)  # hi  * low
    c = c + (a2 * b0)  # low * hi
    c = c + (a1 * b2)  # mid * low
    c = c + (a2 * b1)  # low * mid
    c = c + (a2 * b2)  # low * low
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # fp32 accumulator
result = mul_bf16x3(a, b, c)
```
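The 3-dot and 6-dot variants mentioned in the update above come from dropping
the lowest-order partial products. The exact set of terms kept below is my
reading of how such a reduction would look, not lifted from the patch or the
gist:
```python
def mul_bf16x3_reduced(a, b, c, num_dots=3):
    # Hypothetical reduced-term variant: keep only the num_dots most
    # significant partial products. Which terms the actual pass keeps for
    # its 3-dot / 6-dot modes is an assumption here.
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    terms = [
        (a0, b0),                      # hi  * hi  (most significant)
        (a0, b1), (a1, b0),            # hi  * mid, mid * hi
        (a1, b1), (a0, b2), (a2, b0),  # next order down
    ]
    for x, y in terms[:num_dots]:
        c = c + (x * y)
    return c

approx3 = mul_bf16x3_reduced(a, b, torch.zeros(1, 1), num_dots=3)
approx6 = mul_bf16x3_reduced(a, b, torch.zeros(1, 1), num_dots=6)
```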
The emulation using BF16x3 is used when invoking tl.dot with input
precision 'BF16x3'. The pass is implemented in a GPU-agnostic manner, but
it is needed to make up for the MI350's lack of TF32 support. That part is
still a work in progress and will be based on this patch.
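For context, here is a minimal sketch of how a Triton kernel would opt into
the emulation. The kernel body is just the usual tutorial-style matmul
(assuming block-aligned shapes, no masking); the only piece this patch adds is
the 'bf16x3' input precision on tl.dot, and the exact spelling of that string
is taken from the description above:
```python
import triton
import triton.language as tl

@triton.jit
def matmul_bf16x3_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                         stride_am, stride_ak,
                         stride_bk, stride_bn,
                         stride_cm, stride_cn,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)  # fp32 tiles
        b = tl.load(b_ptrs)
        # The new pass rewrites this fp32 dot into bf16 dots as shown above.
        acc += tl.dot(a, b, input_precision="bf16x3")
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)
```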