
Commit 0c70ca3

Support for tt.dot_scaled operator (#2804)
This PR decomposes a `tt.dot_scaled` operation into a `tt.dot` operation in which one of the operands (e.g., A) is scaled using the `triton_gpu_upcast_mxfp` operation.

Note: the `upcast_mxfp` operation is not lowered to LLVM IR in this PR.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent ef8ded8 commit 0c70ca3

File tree: 6 files changed (+314, -75 lines)
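
Before the per-file diffs, here is a rough standalone sketch of what the MXFP upcast computes conceptually for the `lhs = e2m1` case exercised by the new test: every `i8` of the A operand packs two fp4 (e2m1) values, and every `i8` of the scale tensor is an e8m0 block scale applied to 32 consecutive elements. The value table and the 32-element block size follow the OCP microscaling convention and are assumptions of this illustration, not code from this PR (the actual `upcast_mxfp` lowering is deferred).

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative only: decode a 4-bit e2m1 value (1 sign, 2 exponent, 1 mantissa bit).
// The eight representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6.
static float decodeE2M1(uint8_t nibble) {
  static const float kMagnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                      2.0f, 3.0f, 4.0f, 6.0f};
  float mag = kMagnitude[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}

// Upcast one row: 'packed' holds two e2m1 values per byte (nibble order assumed
// low-then-high), and 'scales' holds one e8m0 exponent, value = 2^(s - 127),
// per block of 32 unpacked elements.
static std::vector<float> upcastMXFP4Row(const std::vector<uint8_t> &packed,
                                         const std::vector<uint8_t> &scales) {
  std::vector<float> out;
  out.reserve(packed.size() * 2);
  for (size_t i = 0; i < packed.size(); ++i) {
    float scale = std::ldexp(1.0f, int(scales[(2 * i) / 32]) - 127);
    out.push_back(decodeE2M1(packed[i] & 0xF) * scale); // low nibble
    out.push_back(decodeE2M1(packed[i] >> 4) * scale);  // high nibble
  }
  return out;
}
```

With the test's shapes, a 128x32 i8 A tile unpacks to 128x64 bf16 and consumes a 128x2 i8 scale tile (64 / 32 = 2 blocks per row).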

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 27 additions & 7 deletions
```diff
@@ -1,3 +1,4 @@
+#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -114,17 +115,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
       retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
     } else {
       auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-      auto newVEncoding = DotOperandEncodingAttr::get(
-          ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-          oldEncoding.getKWidth() * 2);
-      // Figure out the K dimension for the input A/B, given that the return
-      // type is upcasted A/B type so we need to update the proper dim size.
+
       const int opIdx = oldEncoding.getOpIdx();
       const bool hasBatch = xShape.size() == 3;
       const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
       newShape[kIdx] *= 2;
-      retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
-                                    newVEncoding);
+      Type elemType = FloatType::getBF16(ctx);
+
+      // Note: For Intel the dot operands layout's kWidth parameter must
+      // match the parent's DPAS layout opsPerChannel so we need to materialize
+      // a new DPAS layout.
+      Attribute newVEncoding;
+      if (auto dpasEncoding =
+              dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
+        auto newDpasEncoding = intel::DpasEncodingAttr::get(
+            ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
+            dpasEncoding.getExecutionSize(),
+            intel::DpasEncodingAttr::getOpsPerChannel(elemType),
+            dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+            dpasEncoding.getSubGroupSize());
+        newVEncoding = DotOperandEncodingAttr::get(
+            ctx, oldEncoding.getOpIdx(), newDpasEncoding,
+            newDpasEncoding.getOpsPerChannel());
+      } else {
+        // Figure out the K dimension for the input A/B, given that the return
+        // type is upcasted A/B type so we need to update the proper dim size.
+        newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                   oldEncoding.getParent(),
+                                                   oldEncoding.getKWidth() * 2);
+      }
+      retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
     }
     inferredReturnTypes.push_back(retTy);
   } else {
```
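
As a standalone illustration of the return-type logic above (not the MLIR code itself), the snippet below mirrors how the K dimension of the upcast result is located and doubled. With the A operand from the test that follows (shape 128x32, `opIdx = 0`, no batch dimension) it yields 128x64, matching the `tensor<128x64xbf16, ...>` result of `ttg.upcast_mxfp`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the K-dimension bookkeeping in UpcastMXFPOp::inferReturnTypes:
// each packed fp4 byte expands to two bf16 elements, so the K extent doubles.
std::vector<int64_t> inferUpcastedShape(std::vector<int64_t> shape, int opIdx) {
  const bool hasBatch = shape.size() == 3;
  // For A (opIdx == 0) K is the second non-batch dim; for B (opIdx == 1) it is
  // the first non-batch dim.
  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
  shape[kIdx] *= 2;
  return shape;
}

int main() {
  auto shape = inferUpcastedShape({128, 32}, /*opIdx=*/0);
  assert(shape[0] == 128 && shape[1] == 64);
}
```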

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 29 additions & 0 deletions
```diff
@@ -201,3 +201,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
+  // CHECK: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+  // CHECK: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+  // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+  // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
+  // CHECK: dot_scaled
+  tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
+    // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
+    // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+    // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
+    // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[B:%.*]] = ttg.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
+
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.return %result : tensor<128x128xf32, #blocked>
+  }
+}
```
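
The two DPAS layouts in the CHECK lines differ only in `opsPerChan`: 4 for the packed i8 operand and 2 for the bf16 operands and the dot itself. Both values are consistent with dividing the 32-bit DPAS channel by the element bit width, the arithmetic performed by the new `DpasEncodingAttr::getOpsPerChannel` helper; below is a minimal standalone version of that arithmetic (the helper name and fp8 flag are illustrative, not from this PR).

```cpp
// Standalone sketch of the opsPerChannel arithmetic: a DPAS channel is 32 bits
// wide, and fp8 elements are counted as 16 bits because they are upcast to fp16.
constexpr unsigned kOpsChanBitWidth = 32;

constexpr unsigned opsPerChannel(unsigned elemBitWidth, bool isFP8 = false) {
  return kOpsChanBitWidth / (isFP8 ? elemBitWidth * 2 : elemBitWidth);
}

static_assert(opsPerChannel(8) == 4, "i8 operand -> opsPerChan = 4 (DPAS1)");
static_assert(opsPerChannel(16) == 2, "bf16 operand -> opsPerChan = 2 (DPAS)");
static_assert(opsPerChannel(8, /*isFP8=*/true) == 2, "fp8 counts as fp16");
```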

third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -3,6 +3,10 @@
 
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 
+namespace mlir {
+class ModuleOp;
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"
```

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 24 additions & 2 deletions
```diff
@@ -74,7 +74,6 @@ along the row (resp. col) dimension.
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
-
     SmallVector<unsigned> getDPASInstShapeA() const;
     SmallVector<unsigned> getDPASInstShapeB() const;
     SmallVector<unsigned> getDPASInstShapeC() const;
@@ -91,7 +90,30 @@ along the row (resp. col) dimension.
       return true;
     }
 
-    SmallVector<unsigned> getContigPerThread();
+    SmallVector<unsigned> getContigPerThread() const;
+
+    struct DPASCapability {
+      explicit DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
+      DPASCapability() = default;
+
+      bool isPVC() const {
+        return executionSize == 16;
+      }
+      bool isFalconShore() const {
+        return executionSize == 16;
+      }
+      bool isATSM() const {
+        return executionSize == 8;
+      }
+
+      static constexpr unsigned systolicDepth = 8u;
+      static constexpr unsigned repeatCount = 8u;
+      static constexpr unsigned opsChanBitWidths = 32u;
+      unsigned executionSize = 0u;
+    };
+
+    static DPASCapability getDPASCapability(mlir::ModuleOp mod);
+    static unsigned getOpsPerChannel(Type elemType);
   }];
 
   let hasCustomAssemblyFormat = 1;
```

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 31 additions & 14 deletions
```diff
@@ -357,7 +357,7 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   return elemsPerThread;
 };
 
-SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
+SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
   size_t rank = getWarpsPerCTA().size();
   assert(rank == 2 || rank == 3);
   SmallVector<unsigned> contigPerThread(rank, 1);
@@ -381,6 +381,30 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
          "be smaller than the threads required per row.");
 }
 
+DpasEncodingAttr::DPASCapability
+DpasEncodingAttr::getDPASCapability(ModuleOp mod) {
+  assert(mod && "expected a valid module");
+
+  if (auto minSGSizeAttr = mod->getAttrOfType<IntegerAttr>(
+          triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) {
+    unsigned minSGSize = minSGSizeAttr.getInt();
+    assert(minSGSize == 8 || minSGSize == 16 && "unsupported minSGSize");
+    return DPASCapability(minSGSize);
+  }
+
+  return DPASCapability();
+}
+
+unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
+  assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
+
+  unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
+  if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
+    dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
+
+  return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
+}
+
 LogicalResult DpasEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
     unsigned repeatCount, unsigned systolicDepth, unsigned executionSize,
```
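
A hedged usage sketch of the capability query added above: the struct below mirrors `DPASCapability` from this PR minus the MLIR module-attribute lookup, and `main` shows how a module-level `triton_intel_gpu.min_sg_size` of 16 (as in the new test) classifies the target as PVC, while 8 would classify it as ATSM.

```cpp
#include <cassert>

// Simplified mirror of DpasEncodingAttr::DPASCapability; the real helper reads
// the minimum sub-group size from the module's triton_intel_gpu.min_sg_size attr.
struct DPASCapability {
  explicit DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
  DPASCapability() = default;

  bool isPVC() const { return executionSize == 16; }
  bool isATSM() const { return executionSize == 8; }

  static constexpr unsigned systolicDepth = 8u;
  static constexpr unsigned repeatCount = 8u;
  static constexpr unsigned opsChanBitWidths = 32u;
  unsigned executionSize = 0u;
};

int main() {
  DPASCapability pvc(/*minSGSize=*/16);
  DPASCapability atsm(/*minSGSize=*/8);
  assert(pvc.isPVC() && !pvc.isATSM());
  assert(atsm.isATSM() && !atsm.isPVC());
}
```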
```diff
@@ -469,18 +493,14 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
   llvm::ArrayRef<unsigned> rC = shapeC;
   auto warpsPerCTA = getWarpsPerCTA();
   auto repCluster = getRepCluster();
-  printer << "<{"
-          << "repeatCount = " << getRepeatCount() << ", "
+  printer << "<{" << "repeatCount = " << getRepeatCount() << ", "
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
           << "opsPerChan = " << getOpsPerChannel() << ", "
           << "threadsPerWarp = " << getSubGroupSize() << ", "
           << "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
-          << "repCluster = [" << repCluster << "], "
-          << "A = [" << rA << "], "
-          << "B = [" << rB << "], "
-          << "C = [" << rC << "]"
-          << "}>";
+          << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
+          << "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
 }
 
 std::optional<LinearLayout>
@@ -553,13 +573,10 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
 void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
   auto threadsPerWarp = getThreadsPerWarp();
   auto sizePerThread = getSizePerThread();
-  printer << "<{"
-          << "sizePerThread = [" << llvm::ArrayRef<unsigned>(sizePerThread)
-          << "]"
+  printer << "<{" << "sizePerThread = ["
+          << llvm::ArrayRef<unsigned>(sizePerThread) << "]"
           << ", threadsPerWarp = [" << llvm::ArrayRef<unsigned>(threadsPerWarp)
-          << "]"
-          << ", order = [" << getOrder() << "]"
-          << "}>";
+          << "]" << ", order = [" << getOrder() << "]" << "}>";
 }
 
 //===----------------------------------------------------------------------===//
```
