Commit c186592

[FRONTEND] Add scales dimension checks for dot_scaled (#8564)
1 parent 3f5eb50 commit c186592

7 files changed: +101 -8 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 0 deletions
@@ -732,6 +732,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
     `lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict
     `:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d)
   }];
+  let hasVerifier = 1;
 }
 
 //

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 38 additions & 0 deletions
@@ -356,6 +356,44 @@ bool DotScaledOp::verifyOutputDims() {
   return true;
 }
 
+LogicalResult DotScaledOp::verify() {
+  auto aShape = this->getA().getType().getShape();
+  int64_t rank = aShape.size();
+
+  auto k = aShape[rank - 1];
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (this->getLhsKPack())
+      k *= 2;
+  }
+  auto cShape = this->getC().getType().getShape();
+  int64_t mDim = cShape[cShape.size() - 2];
+  int64_t nDim = cShape[cShape.size() - 1];
+
+  if (getAScale()) {
+    auto aScaleShape = getAScale().getType().getShape();
+    if (aScaleShape[rank - 2] != mDim)
+      return this->emitError(
+          "scales M dimension must match the operand M dimension");
+    int scale_factor =
+        isa<Float8E4M3FNType>(getAScale().getType().getElementType()) ? 16 : 32;
+    if (aScaleShape[rank - 1] != k / scale_factor)
+      return this->emitError("scales K dimension must match the operand K "
+                             "divided by the scale factor");
+  }
+  if (getBScale()) {
+    auto bScaleShape = getBScale().getType().getShape();
+    if (bScaleShape[rank - 2] != nDim)
+      return this->emitError(
+          "scales N dimension must match the operand N dimension");
+    int scale_factor =
+        isa<Float8E4M3FNType>(getBScale().getType().getElementType()) ? 16 : 32;
+    if (bScaleShape[rank - 1] != k / scale_factor)
+      return this->emitError("scales K dimension must match the operand K "
+                             "divided by the scale factor");
+  }
+  return success();
+}
+
 //-- MakeRangeOp --
 OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
   // make_range(start, start + 1) -> constant(start)
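
The rule the new verifier enforces: each scale tensor's leading dimension must equal the corresponding M or N dimension of the accumulator, and its trailing dimension must equal K divided by the scale factor (16 when the scale element type is fp8e4m3, 32 otherwise), where K is doubled for e2m1 operands packed along K. A minimal standalone Python sketch of that arithmetic, not part of the commit and with illustrative names only:

# Hypothetical helper mirroring DotScaledOp::verify(); names are not from the Triton codebase.
def expected_scale_shape(m_or_n, k, scale_is_fp8e4m3, elem_is_e2m1=False, k_pack=False):
    if elem_is_e2m1 and k_pack:
        k *= 2  # e2m1 packs two 4-bit values per byte along K
    scale_factor = 16 if scale_is_fp8e4m3 else 32
    return [m_or_n, k // scale_factor]

# Matches the MLIR test updates further below: K = 32 with i8 scales -> 32 // 32 == 1.
assert expected_scale_shape(128, 32, scale_is_fp8e4m3=False) == [128, 1]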

python/test/unit/language/test_compile_errors.py

Lines changed: 20 additions & 0 deletions
@@ -489,3 +489,23 @@ def kernel(N: tl.constexpr):
 
     with pytest.raises(CompilationError, match="N marked as constexpr and listed in do_not_specialize"):
         kernel[(1, )](5)
+
+
+def test_dot_scaled_shape_verification(fresh_triton_cache):
+
+    @triton.jit
+    def kernel():
+        M: tl.constexpr = 32
+        K: tl.constexpr = 64
+        N: tl.constexpr = 32
+        a = tl.full((M, K), 0, tl.uint8)
+        b = tl.full((K, N), 0, tl.uint8)
+        lhs_scale_wrong = tl.full((M, 4), 0, tl.uint8)
+        rhs_scale = tl.full((N, 2), 0, tl.uint8)
+        acc = tl.full((M, N), 0.0, tl.float32)
+        tl.dot_scaled(a, lhs_scale_wrong, "e5m2", b, rhs_scale, "e5m2", acc, False, True, True, tl.float32)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
+
+    assert str(e.value.__cause__) == "lhs_scale must be a tensor of shape [32, 2]. Got ['32', '4']"
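
For contrast, a passing counterpart to the kernel above sizes lhs_scale as (M, K // 32) = (32, 2), the shape named in the error message. This sketch mirrors the positional tl.dot_scaled call from the test and is not part of the commit:

import triton
import triton.language as tl


# Hypothetical fixed variant of the failing test kernel; only lhs_scale's shape changes.
@triton.jit
def kernel_ok():
    M: tl.constexpr = 32
    K: tl.constexpr = 64
    N: tl.constexpr = 32
    a = tl.full((M, K), 0, tl.uint8)
    b = tl.full((K, N), 0, tl.uint8)
    lhs_scale = tl.full((M, 2), 0, tl.uint8)  # K // 32 == 2 for uint8 (non-fp8e4nv) scales
    rhs_scale = tl.full((N, 2), 0, tl.uint8)
    acc = tl.full((M, N), 0.0, tl.float32)
    tl.dot_scaled(a, lhs_scale, "e5m2", b, rhs_scale, "e5m2", acc, False, True, True, tl.float32)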

python/triton/language/semantic.py

Lines changed: 19 additions & 0 deletions
@@ -1590,6 +1590,20 @@ def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
         assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
         return self.bitcast(val, triton_ty)
 
+    def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale):
+        if lhs_scale is not None:
+            scale_factor = 16 if lhs_scale.dtype.is_fp8e4nv() else 32
+            lhs_scale_shape = lhs_scale.type.shape
+            assert lhs_scale_shape == [
+                M, K // scale_factor
+            ], f"lhs_scale must be a tensor of shape [{M}, {K // scale_factor}]. Got {lhs_scale_shape}"
+        if rhs_scale is not None:
+            scale_factor = 16 if rhs_scale.dtype.is_fp8e4nv() else 32
+            rhs_scale_shape = rhs_scale.type.shape
+            assert rhs_scale_shape == [
+                N, K // scale_factor
+            ], f"rhs_scale must be a tensor of shape [{N}, {K // scale_factor}]. Got {rhs_scale_shape}"
+
     def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
                    rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
                    lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
@@ -1621,8 +1635,11 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T
         assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
         #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
         B = lhs.type.shape[0] if lhs_rank == 3 else None
+        K = K_LHS
         if not lhs_k_pack:
             M = M * PACKED_A
+        else:
+            K = K * PACKED_A
         if not rhs_k_pack:
             N = N * PACKED_B
         ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
@@ -1634,6 +1651,8 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T
         assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
         rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
         lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
+        self.verify_scaled_shape(M, N, K, None if lhs_scale_is_none else lhs_scale,
+                                 None if rhs_scale_is_none else rhs_scale)
         return self.tensor(
             self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
                                            rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty)
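
To make the K bookkeeping in the hunks above concrete: when the lhs format is e2m1 and lhs_k_pack is true, the stored K counts bytes and is half of the logical K, so the scale check divides the doubled K by the scale factor. An illustrative arithmetic sketch, not code from semantic.py:

# Assumed example: a 128x32 uint8 storage of an e2m1 (fp4) operand, packed along K.
stored_lhs_shape = (128, 32)
PACKED_A = 2                                  # two fp4 values per uint8 when lhs_k_pack=True
logical_K = stored_lhs_shape[1] * PACKED_A    # 64
scale_factor = 32                             # uint8 scales; 16 only when scales are fp8e4nv
expected_lhs_scale_shape = [stored_lhs_shape[0], logical_K // scale_factor]
assert expected_lhs_scale_shape == [128, 2]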

test/Conversion/tritongpu_to_llvm_sm120.mlir

Lines changed: 4 additions & 4 deletions
@@ -9,17 +9,17 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num
 // CHECK: mma.sync.aligned.m16n8k32.row.col.kind::mxf8f6f4.block_scale.scale_vec::1X
 tt.func public @sm120_mmav2_dot_scaled(
   %a: tensor<128x32xf8E5M2, #blocked_k>,
-  %sa: tensor<128x2xi8, #blocked>,
+  %sa: tensor<128x1xi8, #blocked>,
   %b: tensor<32x128xf8E5M2, #blocked>,
-  %sb: tensor<128x2xi8, #blocked>,
+  %sb: tensor<128x1xi8, #blocked>,
   %out: !tt.ptr<f32>
 ){
   %c = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
   %a_d = ttg.convert_layout %a : tensor<128x32xf8E5M2, #blocked_k> -> tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
   %b_d = ttg.convert_layout %b : tensor<32x128xf8E5M2, #blocked> -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
   %d = tt.dot_scaled %a_d scale %sa, %b_d scale %sb, %c lhs = e5m2 rhs = e5m2 {fastMath = false}
-    : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<128x2xi8, #blocked>
-    * tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, tensor<128x2xi8, #blocked>
+    : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<128x1xi8, #blocked>
+    * tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, tensor<128x1xi8, #blocked>
     -> tensor<128x128xf32, #blocked>
   %out_splat = tt.splat %out : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked>
   %out_ptrs = tt.broadcast %out_splat : tensor<128x1x!tt.ptr<f32>, #blocked> -> tensor<128x128x!tt.ptr<f32>, #blocked>

test/Triton/invalid.mlir

Lines changed: 15 additions & 0 deletions
@@ -541,6 +541,21 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
 
 // -----
 
+module {
+  tt.func @dot_scaled_invalid_dims(
+    %a: tensor<128x128xf8E4M3FN>,
+    %b: tensor<128x128xf8E4M3FN>,
+    %a_scale: tensor<128x128xi8>,
+    %b_scale: tensor<128x4xi8>) -> tensor<128x128xf32> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+    // expected-error @below {{scales K dimension must match the operand K divided by the scale factor}}
+    %result = tt.dot_scaled %a scale %a_scale, %b scale %b_scale, %cst lhs = e4m3 rhs = e4m3 {fastMath = true} : tensor<128x128xf8E4M3FN>, tensor<128x128xi8> * tensor<128x128xf8E4M3FN>, tensor<128x4xi8>-> tensor<128x128xf32>
+    tt.return %result : tensor<128x128xf32>
+  }
+}
+
+// -----
+
 tt.func @unsplat_invalid(%arg0: tensor<128xf32>) {
   // expected-error @below {{source tensor must have exactly one element}}
   %0 = tt.unsplat %arg0 : tensor<128xf32>

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 4 additions & 4 deletions
@@ -680,18 +680,18 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num
 // CHECK-LABEL: @sm120_dot_scaled_basic
 tt.func public @sm120_dot_scaled_basic(
   %a: tensor<128x32xi8, #blocked_k>,
-  %scale_a: tensor<128x2xi8, #blocked>,
+  %scale_a: tensor<128x1xi8, #blocked>,
   %b: tensor<32x128xi8, #blocked>,
-  %scale_b: tensor<128x2xi8, #blocked>
+  %scale_b: tensor<128x1xi8, #blocked>
 ) -> tensor<128x128xf32, #blocked> {
   %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
   // CHECK-DAG: tt.dot_scaled
   // CHECK-DAG: #linear
   // CHECK-DAG: #linear1
   // CHECK-NOT: ttng.tc_gen5_mma_scaled
   %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false}
-    : tensor<128x32xi8, #blocked_k>, tensor<128x2xi8, #blocked>
-    * tensor<32x128xi8, #blocked>, tensor<128x2xi8, #blocked>
+    : tensor<128x32xi8, #blocked_k>, tensor<128x1xi8, #blocked>
+    * tensor<32x128xi8, #blocked>, tensor<128x1xi8, #blocked>
     -> tensor<128x128xf32, #blocked>
   tt.return %d : tensor<128x128xf32, #blocked>
 }
