
Commit 6a4be78

Pipeline scale_dot (#4950)
We allow DotOperand layouts within MemoryOpToLLVM in the buggy Ampere case via linear layouts (LLs). This lets us remove two workarounds added in a previous PR. We add tests in test_pipeliner.py. We also remove some implementation-defined behaviour (overflows / NaNs) in test_core.py, making the tests more resilient and realistic.
1 parent 50080ef commit 6a4be78
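For orientation, here is a minimal sketch of the scaled-dot pattern the new pipelining tests exercise. It is illustrative only: the kernel name, block sizes, pointer layout, and launch are hypothetical; the tl.dot_scaled call mirrors the dot_scale_kernel in the test_core.py diff below (fp8/mxfp data stored as uint8, with one scale byte per 32 elements along K).

import torch
import triton
import triton.language as tl


@triton.jit
def scale_dot_sketch(a_ptr, scale_ptr, b_ptr, out_ptr,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # a and b hold e4m3 values as raw uint8 bytes; scale holds one byte per 32 K-elements of a
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    scale = tl.load(scale_ptr + offs_m[:, None] * (BLOCK_K // 32) + tl.arange(0, BLOCK_K // 32)[None, :])
    c = tl.dot_scaled(a, scale, "e4m3", b, None, "e4m3")
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c.to(tl.bfloat16))


if torch.cuda.is_available():
    M, N, K = 32, 32, 64
    a = torch.randint(256, (M, K), dtype=torch.uint8, device="cuda")
    b = torch.randint(256, (K, N), dtype=torch.uint8, device="cuda")
    scale = torch.randint(128, (M, K // 32), dtype=torch.uint8, device="cuda")
    out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")
    scale_dot_sketch[(1, )](a, scale, b, out, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)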

5 files changed (+269, -71 lines)

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 10 deletions
@@ -90,16 +90,6 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
-      // FIXME [Dot LL]
-      // We support this one via LLs, as the LocalLoad path is buggy
-      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
-        bool largeKWidth =
-            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-        if (mma.isAmpere() && largeKWidth) {
-          return;
-        }
-      }
-
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
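The guard removed above (and its twin removed from ReduceDataDuplication.cpp further down) keyed off the dot operand's kWidth times its element bit width. A tiny illustrative check of that arithmetic, with a hypothetical helper name:

# Illustrative only (hypothetical helper): the removed guard triggered whenever a dot
# operand's kWidth times its element bit width exceeded 64 bits.
def is_large_k_width(k_width: int, elem_bits: int) -> bool:
    return k_width * elem_bits > 64

assert is_large_k_width(8, 16)       # kWidth 8 with bf16 elements: 128 bits
assert not is_large_k_width(4, 16)   # kWidth 4 with bf16 elements: exactly 64 bits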

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 44 additions & 2 deletions
@@ -116,9 +116,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
+    // FIXME [Dot LL]
+    // Do for all DotOperandEncodingAttr once we have LLs for all of them
+    auto isAmpereLargeKWidth = [](Attribute layout) {
+      if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+        if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+          return mma.isAmpere() && dot.getKWidth() == 8;
+        }
+      }
+      return false;
+    };
     if (isa<SharedEncodingAttr>(srcLayout) &&
-        isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
-            dstLayout)) {
+        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
+             dstLayout) ||
+         isAmpereLargeKWidth(dstLayout))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -170,6 +181,37 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     SmallVector<Value> outVals = loadSharedToDistributed(
         dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);

+    // FIXME [Dot LL]
+    // Ampere case
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      if (elemLlvmTy.isInteger(8)) {
+        auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
+          return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
+                     or_(shl(zext(i32_ty, a3), i32_val(16)),
+                         shl(zext(i32_ty, a4), i32_val(24))));
+        };
+        SmallVector<Value> outVals32(outVals.size() / 4);
+        for (int i = 0; i < outVals32.size(); ++i) {
+          outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
+                                outVals[4 * i + 2], outVals[4 * i + 3]);
+        }
+        outVals = outVals32;
+      } else {
+        assert(elemLlvmTy.isBF16() && "Unexpected element type");
+        auto concat = [&](Value a, Value b) {
+          return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                     shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+        };
+
+        SmallVector<Value> outVals32(outVals.size() / 2);
+        for (int i = 0; i < outVals32.size(); ++i) {
+          outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+        }
+        outVals = outVals32;
+      }
+    }
+
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);
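To make the new packing concrete, here is a host-side sketch in plain Python (hypothetical helper names) of the lane concatenation performed before packLLElements: four 8-bit lanes, or two bf16 lanes reinterpreted as 16-bit patterns, end up in one 32-bit value, least-significant lane first.

def pack_i8x4(a1: int, a2: int, a3: int, a4: int) -> int:
    # mirrors the zext/shl/or_ chain for 8-bit elements: a1 lands in bits 0-7, a4 in bits 24-31
    return (a1 & 0xFF) | ((a2 & 0xFF) << 8) | ((a3 & 0xFF) << 16) | ((a4 & 0xFF) << 24)


def pack_bf16x2(lo_bits: int, hi_bits: int) -> int:
    # mirrors the bitcast-to-i16 + zext/shl/or_ chain for bf16 elements
    return (lo_bits & 0xFFFF) | ((hi_bits & 0xFFFF) << 16)


assert pack_i8x4(0x11, 0x22, 0x33, 0x44) == 0x44332211
assert pack_bf16x2(0x3F80, 0x4000) == 0x40003F80  # bit patterns of bf16 1.0 and 2.0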

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 0 additions & 7 deletions
@@ -44,13 +44,6 @@ class TritonGPUReduceDataDuplicationPass
        return;
      if (!cvtNeedsSharedMemory(srcType, dstType))
        return;
-     // FIXME [Dot LL]
-     // We support this one via LLs, as the LocalLoad path is buggy
-     bool largeKWidth =
-         dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-     if (largeKWidth) {
-       return;
-     }
      auto srcOrder = triton::gpu::getOrder(srcEncoding);
      auto rank = srcOrder.size();
      SmallVector<unsigned> sharedOrder;

python/test/unit/language/test_core.py

Lines changed: 46 additions & 28 deletions
@@ -3315,16 +3315,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx


-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", [
-    (M, N, K, col_a, col_b, type_a, type_b, 4)
-    for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
-    for col_a, col_b in itertools.product([True, False], repeat=2)
-    # We don't test e5m2 as its range + the uniform sampling overflows easily
-    # Tested locally and it works fine other than for ~10 entries out of 10_000
-    # which are of the size of 10**30
-    for type_a in ["e2m1", "e4m3"]
-    for type_b in ["e4m3"]
-])
+@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps",
+                         [(M, N, K, col_a, col_b, type_a, type_b, 4)
+                          for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
+                          for col_a, col_b in itertools.product([True, False], repeat=2)
+                          for type_a in ["e2m1", "e4m3", "e5m2"]
+                          for type_b in ["e4m3", "e5m2"]])
 def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
     if not is_cuda():
         pytest.skip("scaled_dot only supported on CUDA")
@@ -3355,7 +3351,7 @@ def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, s
         a_scale = tl.load(scale_a_ptr)
         c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b)
         out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
-        tl.store(out_ptr, c)
+        tl.store(out_ptr, c.to(tl.bfloat16))

     @triton.jit
     def mxfp_to_bf16_kernel(
@@ -3431,7 +3427,6 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]

         comp_dtype = torch.bfloat16
-        out_dtype = torch.float32

         x = x.contiguous()
         x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
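For readers unfamiliar with the mxfp layout that dot_scale_ref upcasts, here is a rough sketch of the decoding it amounts to, assuming the usual E8M0 convention of one shared exponent byte with bias 127 per 32 elements; the helper name and shapes are illustrative, not the library's API.

import torch

def upcast_mxfp_block_sketch(decoded_elems: torch.Tensor, scale_byte: int) -> torch.Tensor:
    # decoded_elems: 32 element values already widened to bf16/fp32;
    # scale_byte: the block's shared uint8 exponent (assumed E8M0, bias 127)
    return decoded_elems * 2.0 ** (int(scale_byte) - 127)

block = torch.full((32, ), 0.5)
print(upcast_mxfp_block_sketch(block, 130))  # 0.5 * 2**3 == 4.0 for each element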
@@ -3440,42 +3435,65 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         BLOCK_SIZE = 512
         grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
         mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
+        assert x_upcast.isfinite().all()

         y_upcast = y.view(type_fp8_y).to(comp_dtype)
-        return torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype))
+
+        class AccumulateInFp32:
+
+            def __enter__(self):
+                self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value
+
+        with AccumulateInFp32():
+            return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))

     torch.manual_seed(0)

-    def create_uint8(shape, col_major=False):
+    def create_uint8(shape, col_major=False, max_val=255):
         if col_major:
             shape = shape[:-2] + (shape[-1], shape[-2])
-        ret = torch.randint(1 << 8, shape, dtype=torch.uint8, device=device)
+        ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
         if col_major:
             ret = ret.mT
         return ret

     DIV_FACTOR = 2 if type_a == "e2m1" else 1
     x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
     y = create_uint8((K, N), col_major=col_b)
-    scale_x = create_uint8((M, K // 32))

-    z = x.new_empty((M, N), dtype=torch.float32)
+    # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
+    # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow
+    m_bytes = int(type_a[1])
+    bias_type_a = 1 << (m_bytes - 1) - 1
+    max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a
+    scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
+
+    def make_finite(x, dtype):
+        # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
+        # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme)
+        if dtype not in ("e5m2", "e4m3"):
+            return x
+        mask = 0x7C if dtype == "e5m2" else 0x7F
+        finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask
+        x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x)
+        x.copy_(x_finite)
+        return x
+
+    x = make_finite(x, type_a)
+    y = make_finite(y, type_b)
+
+    z = x.new_empty((M, N), dtype=torch.bfloat16)
     pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
                                   num_warps=num_warps)

     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

-    # dot_scale_ref computes the result in higher precision
-    # so we equalise all the non-finite values
-    # This also fixes a bug in our upcasting from e5m2 to bf16 where inf is not preserved
-    non_finite_z = ~z.isfinite()
-    z_ref[non_finite_z] = z[non_finite_z]
-    non_finite_ref = ~z_ref.isfinite()
-    z[non_finite_ref] = z_ref[non_finite_ref]
-
-    # generous rtol set because the ref is more precise than the fused
-    # (computes in higher dtype) and we are sampling the whole range of floats
-    torch.testing.assert_close(z, z_ref, equal_nan=True, atol=1e-5, rtol=1e-2)
+    # generous rtol as we are sampling the whole range of floats
+    torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)

     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
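As a standalone illustration of the non-finite filtering idea in make_finite above: in an fp8 byte, an all-ones exponent field marks Inf/NaN (e5m2, exponent mask 0x7C) or NaN (e4m3fn, mask 0x7F), so uniformly sampled bytes hit it fairly often. The snippet below (hypothetical helper, CPU-only) simply counts such patterns for e5m2.

import torch

def count_non_finite_e5m2(raw: torch.Tensor) -> int:
    exp_mask = 0x7C  # the five e5m2 exponent bits sit in bits 2..6
    return int(((raw & exp_mask) == exp_mask).sum())

torch.manual_seed(0)
raw = torch.randint(256, (10_000, ), dtype=torch.uint8)
print(count_non_finite_e5m2(raw))  # roughly 1/32 of uniformly sampled bytes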
