Add Tensor Layout verifier for DPAS layout #4339
@@ -158,3 +158,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
    tt.return %res : tensor<8x16xf16>
  }
}

// -----

// COM: A operand mismatch
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK-LABEL: matmul_tf32dot
  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>

    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

Let's just keep it the way it was - I am comfortable knowing it tests both permutations, and I think the likelihood that they get changed is very low.

I don't feel strongly about this. Does the test fail with the suggested changes?

Yes, the original set of suggested changes caused a failure, but I did not reproduce it locally. It seemed safer and more efficient to use the original test, which produces the expected results and which I verified carefully by disabling validation on A and/or B.

Just tried locally, and it works; it likely failed because line 172 is changed but line 176 is not.

Let's leave it as is - I don't want to have to re-test, and the CI just finished.

    tt.return
  }
}

// -----

// COM: B operand mismatch
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK-LABEL: matmul_tf32dot
  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
                          %a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>

    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

    tt.return
  }
}

// -----

#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK-LABEL: matmul_tf32dot
  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

    tt.return
  }
}
@@ -1203,6 +1203,57 @@ struct TritonIntelGPUInferLayoutInterface
  }
};

struct TritonIntelGPUVerifyTensorLayoutInterface
    : public triton::DialectVerifyTensorLayoutInterface {
  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;

  LogicalResult verifyTensorLayout(
      Attribute layout, RankedTensorType rankedTy, Operation *op,
      function_ref<InFlightDiagnostic()> makeErr) const override {

    // Verify that the DPAS layout opsPerChannel param matches the A and B
    // operand types. Because the DotOperand layout is not part of the Triton
    // Intel GPU dialect, we need to first check for a tt.dot operation. Then,
    // we can compare the type of each operand to the Dot operation with the
    // DPAS layout attached to the Dot operation.
    if (auto dpasEncoding = dyn_cast<DpasEncodingAttr>(layout)) {

      auto validateDotDpasLayout = [&](Type elemTy) -> LogicalResult {
        if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
          elemTy = ptrTy.getPointeeType();
        }
        const unsigned elemTyBitWidth = elemTy.getIntOrFloatBitWidth();

        // We know opsPerChannel is either 1, 2, or 4 because of the DPAS
        // verifier when the DPAS attribute is created. Here we verify that
        // opsPerChannel matches the tensor type.
        if (dpasEncoding.getOpsPerChannel() * elemTyBitWidth != 32) {
          return makeErr() << layout << ".\nLayout has opsPerChannel = "
                           << dpasEncoding.getOpsPerChannel()
                           << " but tensor element type is " << elemTy
                           << ". Expected "
                           << 32 / dpasEncoding.getOpsPerChannel()
                           << " bit type.";
        }
        return success();
      };

      if (auto dotOp = dyn_cast<DotOp>(op)) {
        auto aElemTy = dotOp.getA().getType().getElementType();
        auto bElemTy = dotOp.getB().getType().getElementType();

        auto aResult = validateDotDpasLayout(aElemTy);
        if (aResult.failed())
          return aResult;

        return validateDotDpasLayout(bElemTy);
      }
    }

    return success();
  }
};
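
The invariant enforced by the lambda above is that opsPerChannel times the operand element bit width equals 32, i.e. one 32-bit DPAS channel. The following standalone sketch is illustrative only (the helper name is hypothetical, not part of this PR) and just shows which pairings pass the check:

#include <cassert>

// Hypothetical helper, for illustration: each DPAS channel packs 32 bits,
// so a valid element type must be exactly 32 / opsPerChannel bits wide.
constexpr unsigned expectedElemBitWidth(unsigned opsPerChannel) {
  return 32u / opsPerChannel;
}

int main() {
  assert(expectedElemBitWidth(1) == 32); // e.g. f32 / tf32
  assert(expectedElemBitWidth(2) == 16); // e.g. f16 / bf16
  assert(expectedElemBitWidth(4) == 8);  // e.g. 8-bit integer or float types
  return 0;
}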

//===----------------------------------------------------------------------===//

void TritonIntelGPUDialect::initialize() {

@@ -1212,6 +1263,7 @@ void TritonIntelGPUDialect::initialize() {
      >();

  addInterfaces<TritonIntelGPUInferLayoutInterface>();
  addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();

There is a discussion with upstream Triton.

The verify tensor layout interface is called via the dialect of the layout attribute; note that the dialect comes from the layout attribute and not the operation. Why would we need to call the Triton GPU dialect interface / use it as the parent class, when the layouts (attributes) it operates on are not part of our dialect?

It seems there is no further design update after that discussion. As far as I can recall, there were cases where layouts from the Triton GPU dialect would also go into the Triton Intel GPU dialect verify/infer layout interface. The reason to "use the TritonGPUVerify Interface as the parent class" is to reuse common code and reduce duplication.

There are some basic legality checks implemented in the Triton GPU dialect interface which are valid for the third_party GPU dialects as well. But right now those basic checks are missed when checking a layout defined in third_party.

I don't know about ... Now if the layout were a ...

  addOperations<
#define GET_OP_LIST