Add Tensor Layout verifier for DPAS layout #4339

Merged · 4 commits · Jun 23, 2025
56 changes: 56 additions & 0 deletions test/TritonIntelGPU/tritonintelgpu-invalid.mlir
@@ -158,3 +158,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
tt.return %res : tensor<8x16xf16>
}
}


// -----

// COM: A operand mismatch
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
%a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>

// expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
Contributor suggested change:
- %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+ %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>

Contributor Author:
Let's just keep it the way it was - I am comfortable knowing it tests both permutations and I think the likelihood that they get changed is very low.

Contributor:
I don't feel strongly about this. Does the test fail with the suggested changes?

Contributor Author:
Yes, the original set of suggested changes caused a failure, but I did not reproduce it locally. It seemed safer and more efficient to keep the original test, which produces the expected results and which I verified carefully by disabling validation on A and/or B.

Contributor:

Just tried locally, and it works; it likely failed before because line 172 was changed but line 176 was not.

Contributor Author:
Let's leave it as is - I don't want to have to re-test, and the CI just finished.


tt.return
}
}

// -----

// COM: B operand mismatch
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
%a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>

// expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

tt.return
}
}

// -----

#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
%a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

tt.return
}
}
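
For reference, the invariant these negative tests exercise is that opsPerChannel multiplied by the operand element bit width must equal 32: with opsPerChan = 2, a 32-bit f32 operand gives 2 * 32 = 64, so the verifier rejects it and asks for a 16-bit type (f16/bf16). A standalone sketch of that arithmetic, for illustration only and not code from this PR:

#include <cstdio>

// Sketch of the check added in this PR: a DPAS layout is compatible with an
// operand element type only if opsPerChannel * elementBitWidth == 32.
static bool dpasOperandCompatible(unsigned opsPerChannel, unsigned elemBits) {
  return opsPerChannel * elemBits == 32;
}

int main() {
  // opsPerChan = 2 with f32 operands (32 bits): 2 * 32 = 64 != 32 -> rejected,
  // matching the expected-error checks in the tests above.
  std::printf("opsPerChan=2, f32: %s\n",
              dpasOperandCompatible(2, 32) ? "ok" : "error, expected 16 bit type");
  // opsPerChan = 2 with f16 operands (16 bits): 2 * 16 = 32 -> accepted.
  std::printf("opsPerChan=2, f16: %s\n",
              dpasOperandCompatible(2, 16) ? "ok" : "error");
  return 0;
}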
52 changes: 52 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -1203,6 +1203,57 @@ struct TritonIntelGPUInferLayoutInterface
}
};

struct TritonIntelGPUVerifyTensorLayoutInterface
    : public triton::DialectVerifyTensorLayoutInterface {
  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;

  LogicalResult verifyTensorLayout(
      Attribute layout, RankedTensorType rankedTy, Operation *op,
      function_ref<InFlightDiagnostic()> makeErr) const override {

    // Verify that the DPAS layout opsPerChannel param matches the A and B
    // operand types. Because the DotOperand layout is not part of the Triton
    // Intel GPU dialect, we need to first check for a tt.dot operation. Then,
    // we can compare the type of each operand to the Dot operation with the
    // DPAS layout attached to the Dot operation.
    if (auto dpasEncoding = dyn_cast<DpasEncodingAttr>(layout)) {

      auto validateDotDpasLayout = [&](Type elemTy) -> LogicalResult {
        if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
          elemTy = ptrTy.getPointeeType();
        }
        const unsigned elemTyBitWidth = elemTy.getIntOrFloatBitWidth();

        // We know opsPerChannel is either 1, 2, or 4 because of the DPAS
        // verifier when the DPAS attribute is created. Here we verify that
        // opsPerChannel matches the tensor type.
        if (dpasEncoding.getOpsPerChannel() * elemTyBitWidth != 32) {
          return makeErr() << layout << ".\nLayout has opsPerChannel = "
                           << dpasEncoding.getOpsPerChannel()
                           << " but tensor element type is " << elemTy
                           << ". Expected "
                           << 32 / dpasEncoding.getOpsPerChannel()
                           << " bit type.";
        }
        return success();
      };

      if (auto dotOp = dyn_cast<DotOp>(op)) {
        auto aElemTy = dotOp.getA().getType().getElementType();
        auto bElemTy = dotOp.getB().getType().getElementType();

        auto aResult = validateDotDpasLayout(aElemTy);
        if (aResult.failed())
          return aResult;

        return validateDotDpasLayout(bElemTy);
      }
    }

    return success();
  }
};

//===----------------------------------------------------------------------===//

void TritonIntelGPUDialect::initialize() {
@@ -1212,6 +1263,7 @@ void TritonIntelGPUDialect::initialize() {
>();

addInterfaces<TritonIntelGPUInferLayoutInterface>();
addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();
Contributor:
There is a discussion with upstream Triton: we want the third-party dialect to be able to use the TritonGPU verify interface as the parent class.
@LiyangLingIntel, do you know what the upstream response was, and which issue tracks the discussion?

Contributor Author:
The verify tensor layout interface is called via

    Dialect &dialect = layout.getDialect();
    auto verifyLayoutInterface =
        dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
    if (verifyLayoutInterface) {
      return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op,
                                                       makeErr);
    }

Note that the dialect comes from the layout attribute, not from the operation. Why would we need to call the Triton GPU dialect interface / use it as the parent class, when the layouts (attributes) it operates on are not part of our dialect?
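
To make that concrete, here is a small sketch of the ownership-based dispatch; it only assumes the standard mlir::Attribute::getDialect() accessor and the attribute names already used in this PR, and it is illustration rather than PR code:

#include "mlir/IR/Attributes.h"

// Layout verification is dispatched on the dialect that defines the layout
// attribute, not on the operation and not on a layout's parent encoding:
//   #ttig.dpas<...>                    -> TritonIntelGPUDialect, i.e. the
//                                         interface added in this PR.
//   #ttg.dot_op<{parent = #ttig.dpas}> -> TritonGPUDialect, even though its
//                                         parent encoding is an Intel layout.
mlir::Dialect &owningDialect(mlir::Attribute layout) {
  return layout.getDialect();
}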

Contributor:
> Do you know what the upstream response was, and which issue tracks the discussion?

It seems there has been no further design update since that discussion.

> Why would we need to call the Triton GPU dialect interface / use it as the parent class, when the layouts (attributes) it operates on are not part of our dialect?

As I recall, there were cases where layouts from the Triton GPU dialect would also go through the Triton Intel GPU dialect verify/infer layout interface. The reason to "use the TritonGPU verify interface as the parent class" is to reuse common code and reduce duplication.

Contributor:
There are some basic legality checks implemented in the Triton GPU dialect interface that are valid for the third_party GPU dialects as well.

But right now those basic checks are skipped when verifying a layout defined in third_party.
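
For illustration, the reuse being asked for could look roughly like the sketch below if upstream exported its verifier as a reusable base class. The base-class name TritonGPUVerifyTensorLayoutInterface and the verifyDpasDotOperands helper are assumptions made for the sketch, not existing exported APIs, and this is not what the PR implements:

// Hypothetical sketch: assumes upstream exposes its TritonGPU tensor-layout
// verifier base class in a header, which the discussion above suggests is not
// the case today. The Intel interface would then inherit the common legality
// checks and only add the DPAS-specific one.
struct TritonIntelGPUVerifyTensorLayoutInterface
    : public triton::gpu::TritonGPUVerifyTensorLayoutInterface {
  using TritonGPUVerifyTensorLayoutInterface::TritonGPUVerifyTensorLayoutInterface;

  LogicalResult verifyTensorLayout(
      Attribute layout, RankedTensorType rankedTy, Operation *op,
      function_ref<InFlightDiagnostic()> makeErr) const override {
    // Run the generic checks implemented upstream first...
    if (failed(triton::gpu::TritonGPUVerifyTensorLayoutInterface::
                   verifyTensorLayout(layout, rankedTy, op, makeErr)))
      return failure();
    // ...then the DPAS opsPerChannel check from the diff above, assumed to be
    // factored into a local helper for this sketch.
    return verifyDpasDotOperands(layout, rankedTy, op, makeErr);
  }
};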

Contributor Author:
I don't know about the inferLayoutInterface, but I am curious about the verifyLayoutInterface, since the dialect comes directly from the layout and, as far as I know, there are no layouts shared between dialects.

Now, if the layout were a DotOperandEncoding layout with a parent from the Intel dialect, I could see how that might pose a problem, since DotOperandEncoding would never hit the Intel dialect verifier. But I don't understand how the opposite could be true.


addOperations<
#define GET_OP_LIST