
Commit 462de12

Support scaled_dot with rhs scale (#5107)
This enables support for scaled_dot with an rhs scale. Only one scale is still supported at a time, on either lhs or rhs. For simplicity, for MMAv2 we just transpose the operands in this case and sink the transpose op on the destination. This prepares us for MMAv3 support, where the scales should always be on lhs.
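
For context, a minimal sketch of what this unlocks at the Triton language level, modeled on the updated test below: the lhs passes None for its scale and only the rhs carries an e8m0 scale tensor. The kernel name, pointer layout, and row-major strides here are illustrative assumptions, not part of this commit.

import triton
import triton.language as tl

@triton.jit
def scaled_dot_rhs_kernel(a_ptr, b_ptr, b_scale_ptr, out_ptr,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    PACKED_K: tl.constexpr = BLOCK_K // 2   # e2m1 packs two values per byte along K
    SCALE_K: tl.constexpr = BLOCK_K // 32   # one e8m0 scale per 32 elements along K
    a = tl.load(a_ptr + tl.arange(0, BLOCK_M)[:, None] * BLOCK_K + tl.arange(0, BLOCK_K)[None, :])
    b = tl.load(b_ptr + tl.arange(0, PACKED_K)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :])
    b_scale = tl.load(b_scale_ptr + tl.arange(0, BLOCK_N)[:, None] * SCALE_K + tl.arange(0, SCALE_K)[None, :])
    # No lhs scale: only the rhs is a block-scaled (mx) operand.
    c = tl.dot_scaled(a, None, "e4m3", b, b_scale, "e2m1")
    out_ptrs = out_ptr + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptrs, c.to(tl.bfloat16))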
1 parent f9a4bbb commit 462de12

File tree

6 files changed: +263, -56 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 6 additions & 4 deletions
@@ -86,7 +86,8 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
   // TODO: Add verifier
 }
 
-def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
+def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
+                                     SameOperandsAndResultShape,
                                      SameOperandsAndResultEncoding,
                                      Pure,
                                      /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
@@ -675,6 +676,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 // DotScaled Op
 //
 def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
+                                          AttrSizedOperandSegments,
                                           DotLike,
                                           TypesMatchWith<"result's type matches accumulator's type",
                                                          "d", "c", "$_self">]> {
@@ -692,7 +694,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
     RankedTensorOf<[TT_Float,I8]>:$lhs,
     RankedTensorOf<[TT_Float,I8]>:$rhs,
     TT_FloatTensor:$c,
-    RankedTensorOf<[I8]>:$lhs_scale,
+    Optional<RankedTensorOf<[I8]>>:$lhs_scale,
     Optional<RankedTensorOf<[I8]>>:$rhs_scale,
     TT_ScaleDotElemTypeAttr:$lhs_type,
     TT_ScaleDotElemTypeAttr:$rhs_type
@@ -702,8 +704,8 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
   // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file
   let assemblyFormat = [{
-    $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
-    `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
+    $lhs (`scale` $lhs_scale^)? `,` $rhs (`scale` $rhs_scale^)? `,` $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
+    `:` type($lhs) (`,` type($lhs_scale)^)? `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
   }];
 }

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 142 additions & 0 deletions
@@ -567,6 +567,147 @@ class DecomposeScaledBlocked
   }
 };
 
+static void updateValueType(Value v, Attribute encoding,
+                            ArrayRef<int64_t> shape) {
+  auto tensorType = cast<RankedTensorType>(v.getType());
+  auto newType =
+      RankedTensorType::get(shape, tensorType.getElementType(), encoding);
+  v.setType(newType);
+}
+
+static TransOp updateUsers(Value result, const SetVector<Operation *> &slice) {
+  TransOp transOp;
+  if (llvm::any_of(result.getUsers(),
+                   [&](Operation *user) { return slice.count(user) == 0; })) {
+    OpBuilder builder(result.getContext());
+    builder.setInsertionPointAfterValue(result);
+    transOp =
+        builder.create<TransOp>(result.getLoc(), result, ArrayRef({1, 0}));
+    result.replaceUsesWithIf(transOp.getResult(), [&](OpOperand &operand) {
+      return operand.getOwner() != transOp.getOperation() &&
+             slice.count(operand.getOwner()) == 0;
+    });
+  }
+  return transOp;
+}
+
+// Sink the transpose in the IR. This is done to avoid generating a convert
+// layout when we have a transpose right after a dot, as the mma layout cannot
+// be propagated through a transpose op. Once we have layouts that can
+// represent a transposed MMA we can remove this transformation.
+static void sinkTransposeOp(TransOp input) {
+  SmallVector<TransOp> queue = {input};
+  while (!queue.empty()) {
+    TransOp transOp = queue.back();
+    Value currentValue = transOp.getResult();
+    queue.pop_back();
+    mlir::ForwardSliceOptions options;
+    options.filter = [](Operation *op) {
+      if (op->hasTrait<OpTrait::Elementwise>() && op->getNumOperands() == 1)
+        return true;
+      if (isa<scf::YieldOp>(op))
+        return isa<scf::ForOp>(op->getParentOp());
+      if (isa<ConvertLayoutOp>(op))
+        return true;
+      return false;
+    };
+    SetVector<Operation *> slice;
+    mlir::getForwardSlice(currentValue, &slice, options);
+    for (Operation *op : slice) {
+      if (op->hasTrait<OpTrait::Elementwise>()) {
+        // Update users of transpose op.
+        if (op->getOperand(0) == transOp.getResult())
+          op->setOperand(0, transOp.getOperand());
+        // Update the type of the result.
+        for (Value result : op->getResults()) {
+          auto srcType = cast<RankedTensorType>(op->getOperand(0).getType());
+          updateValueType(result, srcType.getEncoding(), srcType.getShape());
+          updateUsers(result, slice);
+        }
+        continue;
+      }
+      if (auto cvtOp = dyn_cast<ConvertLayoutOp>(op)) {
+        // Update users of transpose op.
+        if (op->getOperand(0) == transOp.getResult())
+          op->setOperand(0, transOp.getOperand());
+        auto resultEncoding = cvtOp.getType().getEncoding();
+        auto newDstEncoding = inferSrcEncoding(transOp, resultEncoding);
+        auto srcType = cast<RankedTensorType>(cvtOp.getOperand().getType());
+        updateValueType(cvtOp.getResult(), *newDstEncoding, srcType.getShape());
+        updateUsers(cvtOp.getResult(), slice);
+        continue;
+      }
+      assert(isa<scf::YieldOp>(op));
+      auto forOp = dyn_cast<scf::ForOp>(op->getParentOp());
+      assert(forOp);
+      for (OpOperand &operand : op->getOpOperands()) {
+        Operation *def = operand.get().getDefiningOp();
+        if (def && (slice.count(def)) || def == transOp.getOperation()) {
+          if (def == transOp.getOperation())
+            operand.set(transOp.getOperand());
+          Type newType = operand.get().getType();
+          forOp.getResult(operand.getOperandNumber()).setType(newType);
+          TransOp retTrans =
+              updateUsers(forOp.getResult(operand.getOperandNumber()), slice);
+          // Recursively try to propagate the new transpose inserted.
+          if (retTrans)
+            queue.push_back(retTrans);
+          forOp.getRegionIterArg(operand.getOperandNumber()).setType(newType);
+          TransOp argTrans = updateUsers(
+              forOp.getRegionIterArg(operand.getOperandNumber()), slice);
+          if (argTrans)
+            queue.push_back(argTrans);
+          OpBuilder builder(forOp);
+          OpOperand &init = forOp.getInitsMutable()[operand.getOperandNumber()];
+          Value initTranspose = builder.create<TransOp>(
+              forOp.getLoc(), init.get(), ArrayRef({1, 0}));
+          init.set(initTranspose);
+        }
+      }
+    }
+  }
+}
+
+// Transpose scaled_dot ops that have their scale on rhs so that the scale ends
+// up on lhs.
+static Operation *transposeDotOp(DotScaledOp dotOp) {
+  OpBuilder builder(dotOp);
+  Value lhs = dotOp.getLhs();
+  std::array<int, 2> transOrder = {1, 0};
+  Value lhsTransposed = builder.create<TransOp>(lhs.getLoc(), lhs, transOrder);
+  Value rhs = dotOp.getRhs();
+  Value rhsTransposed = builder.create<TransOp>(rhs.getLoc(), rhs, transOrder);
+  Value c = dotOp.getC();
+  Value cTransposed = builder.create<TransOp>(c.getLoc(), c, transOrder);
+  Value result = builder.create<DotScaledOp>(
+      dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
+      cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(),
+      dotOp.getLhsType());
+  Operation *transposedResult =
+      builder.create<TransOp>(result.getLoc(), result, transOrder);
+  dotOp.replaceAllUsesWith(transposedResult);
+  dotOp.erase();
+  return transposedResult;
+}
+
+static void transposeDots(ModuleOp m) {
+  // TODO: extend to regular dot when it is profitable. For instance when we
+  // may want to use rhs from register for mmav3.
+  SmallVector<DotScaledOp> toTranspose;
+  m.walk([&](DotScaledOp dotOp) -> void {
+    if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
+      toTranspose.push_back(dotOp);
+  });
+  SmallVector<Operation *> transposes;
+  for (DotScaledOp dotOp : toTranspose) {
+    Operation *transpose = transposeDotOp(dotOp);
+    transposes.push_back(transpose);
+  }
+
+  for (Operation *transpose : transposes) {
+    sinkTransposeOp(cast<TransOp>(transpose));
+  }
+}
+
 #define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
@@ -582,6 +723,7 @@ class TritonGPUAccelerateMatmulPass
     ModuleOp m = getOperation();
 
     auto computeCapability = getNVIDIAComputeCapability(m);
+    transposeDots(m);
 
     mlir::RewritePatternSet patterns(context);
     patterns.add<BlockedToMMA, DecomposeScaledBlocked>(context,
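
The rewrite above relies on the standard transpose identity for matrix products: a dot whose scale sits on rhs is replaced by a dot on the swapped, transposed operands (so the scaled operand becomes the lhs), followed by a transpose of the result, which sinkTransposeOp then pushes past elementwise ops, convert_layouts, and loop yields. A small torch sketch of just the algebraic equivalence, with plain float tensors standing in for the scaled operands; it illustrates the shape/ordering argument only, not the mx decoding:

import torch

M, N, K = 32, 64, 128
a = torch.randn(M, K, dtype=torch.float64)
b = torch.randn(K, N, dtype=torch.float64)   # imagine the scale attached to this operand

c_direct = a @ b                 # scaled_dot(a, b) with the scale on rhs
c_rewritten = (b.mT @ a.mT).mT   # scaled_dot(b^T, a^T) with the scale on lhs, then transpose

torch.testing.assert_close(c_direct, c_rewritten)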

python/src/ir.cc

Lines changed: 3 additions & 2 deletions
@@ -1481,12 +1481,13 @@ void init_triton_ir(py::module &&m) {
                                  maxNumImpreciseAcc);
           })
      .def("create_dot_scaled",
-          [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale,
+          [](TritonOpBuilder &self, mlir::Value &lhs,
+             std::optional<mlir::Value> &lhs_scale,
              ScaleDotElemType lhs_format, mlir::Value &rhs,
              std::optional<mlir::Value> &rhs_scale,
              ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value {
             return self.create<DotScaledOp>(
-                c.getType(), lhs, rhs, c, lhs_scale,
+                c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()),
                 rhs_scale.value_or(Value()), lhs_format, rhs_format);
           })
      .def("create_floor",

python/test/unit/language/test_core.py

Lines changed: 65 additions & 42 deletions
@@ -3367,48 +3367,55 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
 
 
-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack",
-                         [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack)
+@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack",
+                         [(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, 4, mma, kpack)
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
-                          for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2", "bf16"]
+                          for rhs_scale in [False, True]
+                          for normal_type in ["e2m1", "e4m3", "e5m2"]
+                          for mxfp_type in ["e4m3", "e5m2", "bf16"]
                           for mma in ([32, 16] if is_hip() else [16])
                           for kpack in ([1, 2] if is_hip() else [1])])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
+def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack, device):
     if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
+        if rhs_scale:
+            pytest.skip("scales on rhs not yet supported for HIP")
         if not is_hip_cdna():
             pytest.skip("scaled_dot only implemented for HIP CDNA")
-        if "e4m3" in (type_a, type_b) and not is_hip_mi300():
-            pytest.skip(f"scaled_dot({type_a}, {type_b}) only implemented for MI300")
+        if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300():
+            pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
     @triton.jit
-    def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
+    def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                          type_b: tl.constexpr):
-        tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
-        IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
-        DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
-        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
-        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K
+        DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1
+        DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1
+        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
+        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
         a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0,
                                                                                 PACKED_BLOCK_K_A)[None, :] * stride_a1
         b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0,
                                                                                          BLOCK_N)[None, :] * stride_b1
 
-        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
-        scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
-
         a = tl.load(a_ptr)
         b = tl.load(b_ptr)
-        a_scale = tl.load(scale_a_ptr)
-        c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b)
+        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
+        if a_scale is not None:
+            scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0,
+                                                                                               SCALE_BLOCK_K)[None, :]
+            a_scale = tl.load(scale_a_ptr)
+        if b_scale is not None:
+            scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0,
+                                                                                               SCALE_BLOCK_K)[None, :]
+            b_scale = tl.load(scale_b_ptr)
+        c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b)
         out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
         tl.store(out_ptr, c.to(tl.bfloat16))
 
@@ -3481,22 +3488,31 @@ def mxfp_to_bf16_kernel(
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32)
 
-    def dot_scale_ref(x, scale, y, type_x, type_y):
-        e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
-        type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]
-
-        comp_dtype = torch.bfloat16
-
-        x = x.contiguous()
-        x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
-
-        N = x_upcast.numel()
-        BLOCK_SIZE = 512
-        grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
-        mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
-        assert x_upcast.isfinite().all()
-
-        y_upcast = y.view(type_y).to(comp_dtype)
+    def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y):
+
+        def upcast(v, scale, type, transposed):
+            comp_dtype = torch.bfloat16
+            if scale is None:
+                type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type]
+                return v.view(type).to(comp_dtype)
+            e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type]
+            # Packing is always on the K dimension, so we transpose before upcasting, then transpose back.
+            if transposed:
+                v = v.mT.contiguous()
+            v = v.contiguous()
+            v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
+            N = v_upcast.numel()
+            BLOCK_SIZE = 512
+            grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
+            mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE,
+                                      num_warps=num_warps)
+            assert v_upcast.isfinite().all()
+            if transposed:
+                v_upcast = v_upcast.mT
+            return v_upcast
+
+        x_upcast = upcast(x, scale_x, type_x, False)
+        y_upcast = upcast(y, scale_y, type_y, True)
 
         class AccumulateInFp32:
 
@@ -3525,13 +3541,22 @@ def make_arg(shape, ty, col_major=False, max_val=255):
             ret = ret.mT
         return ret
 
-    DIV_FACTOR = 2 if type_a == "e2m1" else 1
-    x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
-    y = make_arg((K, N), type_b, col_major=col_b)
+    type_a = normal_type if not rhs_scale else mxfp_type
+    type_b = mxfp_type if not rhs_scale else normal_type
+
+    DIV_FACTOR_A = 2 if type_a == "e2m1" else 1
+    DIV_FACTOR_B = 2 if type_b == "e2m1" else 1
+    x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a)
+    y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b)
 
     # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
     # Max scale = 2**15
     scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)
+    scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15)
+    if rhs_scale:
+        scale_x = None
+    else:
+        scale_y = None
 
     def make_finite(x, dtype):
         # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
@@ -3546,16 +3571,14 @@ def make_finite(x, dtype):
 
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)
-
     kernel_kwargs = {"num_warps": num_warps}
     if is_hip():
         kernel_kwargs["kpack"] = kpack
         kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)
-
-    z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
-
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b,
+                                  **kernel_kwargs)
+    z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b)
     # Bigger tolerance for AMD MI200 devices.
     # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values
     # to zero. Detailed info is at:
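
For intuition about what dot_scale_ref checks, here is a rough torch-only restatement of the reference semantics, assuming the e8m0 scale byte encodes a factor of 2**(byte - 127) applied to each group of 32 elements along K (consistent with the K // 32 scale shapes and the "Max scale = 2**15" comment in the test). This is only an illustration; the test itself upcasts through mxfp_to_bf16_kernel.

import torch

def scaled_matmul_ref(a_bf16, b_bf16, scale_a=None, scale_b=None):
    # a_bf16: (M, K), b_bf16: (K, N), already upcast to bfloat16.
    # scale_a: uint8 (M, K // 32) or None; scale_b: uint8 (N, K // 32) or None.
    def apply(v, scale, along_rows):
        if scale is None:
            return v.float()
        # Broadcast each e8m0 scale over its group of 32 K elements.
        factors = torch.pow(2.0, scale.float() - 127).repeat_interleave(32, dim=-1)
        return v.float() * (factors if along_rows else factors.mT)
    a = apply(a_bf16, scale_a, True)    # scales laid out along rows of a
    b = apply(b_bf16, scale_b, False)   # scales laid out along rows of b^T
    return (a @ b).to(torch.bfloat16)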
