34 changes: 27 additions & 7 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -1,3 +1,4 @@
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -114,17 +115,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
newVEncoding);
Type elemType = FloatType::getBF16(ctx);
Contributor: Can we define this inside the if statement below?

Contributor Author: I'd rather have it here because elemType is used after the if/else at line 147.


// Note: for Intel, the dot operand layout's kWidth parameter must match
// the parent DPAS layout's opsPerChannel, so we need to materialize a new
// DPAS layout.
Attribute newVEncoding;
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), newDpasEncoding,
newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
}
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
}
inferredReturnTypes.push_back(retTy);
} else {
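For reference, a minimal standalone sketch of the K-dimension update performed above (the helper name and the 2-D simplification are illustrative, not part of the patch): an e2m1 operand packs two fp4 values per i8, so the K dimension of the upcast bf16 result is doubled.

#include <array>
#include <cassert>
#include <cstdint>

// Mirrors the kIdx/newShape logic in UpcastMXFPOp::inferReturnTypes for the
// 2-D case: operand A (opIdx 0) has K as its last dimension, operand B
// (opIdx 1) has K as its first dimension.
std::array<int64_t, 2> upcastedShape(std::array<int64_t, 2> shape, int opIdx) {
  const int kIdx = (opIdx == 0 ? 1 : 0);
  shape[kIdx] *= 2; // two fp4 values per packed i8
  return shape;
}

int main() {
  // Matches the test below: tensor<128x32xi8> (A) upcasts to tensor<128x64xbf16>.
  auto shape = upcastedShape({128, 32}, /*opIdx=*/0);
  assert(shape[0] == 128 && shape[1] == 64);
  return 0;
}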
29 changes: 29 additions & 0 deletions test/TritonIntelGPU/accelerate-matmul-pvc.mlir
@@ -201,3 +201,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>

module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
// CHECK: [[BLOCKED:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: [[BLOCKED1:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: [[BLOCKED2:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
// CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
// CHECK: dot_scaled
tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
// CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
// CHECK: [[C:%.*]] = triton_gpu.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
// CHECK: [[CVT_ARG0:%.*]] = triton_gpu.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
// CHECK: [[CVT_ARG1:%.*]] = triton_gpu.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
// CHECK: [[A:%.*]] = triton_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
// CHECK: [[B:%.*]] = triton_gpu.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
// CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
// CHECK: [[RES:%.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>

%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
tt.return %result : tensor<128x128xf32, #blocked>
}
}
@@ -3,6 +3,10 @@

#include "triton/Dialect/TritonGPU/IR/Attributes.h"

namespace mlir {
class ModuleOp;
}
Comment on lines +6 to +8

Contributor: I'd bet we don't need this

Contributor Author: ModuleOp is used in "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc" now:

static DPASCapability getDPASCapability(mlir::ModuleOp mod);

That is why I put the forward declaration here.
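As a minimal illustration of why the forward declaration is sufficient (the struct and method below are hypothetical, not from the patch): a function declaration may take an incomplete type by value; the complete type is only needed where the function is defined or called.

namespace mlir {
class ModuleOp; // forward declaration: ModuleOp is incomplete at this point
} // namespace mlir

struct ExampleCapabilityQuery {
  // Declaration only, so the compiler does not need ModuleOp's definition here;
  // including mlir/IR/BuiltinOps.h can be deferred to the .cpp that defines it.
  static unsigned query(mlir::ModuleOp mod);
};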


#define GET_ATTRDEF_CLASSES
#include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"

@@ -74,7 +74,6 @@ along the row (resp. col) dimension.
);

let extraClassDeclaration = extraDistributedDeclaration # [{

SmallVector<unsigned> getDPASInstShapeA() const;
SmallVector<unsigned> getDPASInstShapeB() const;
SmallVector<unsigned> getDPASInstShapeC() const;
@@ -91,7 +90,30 @@
return true;
}

SmallVector<unsigned> getContigPerThread();
SmallVector<unsigned> getContigPerThread() const;

struct DPASCapability {
explicit DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
DPASCapability() = default;

bool isPVC() const {
return executionSize == 16;
}
bool isFalconShore() const {
return executionSize == 16;
}
bool isATSM() const {
return executionSize == 8;
}

static constexpr unsigned systolicDepth = 8u;
static constexpr unsigned repeatCount = 8u;
static constexpr unsigned opsChanBitWidths = 32u;
unsigned executionSize = 0u;
};

static DPASCapability getDPASCapability(mlir::ModuleOp mod);
static unsigned getOpsPerChannel(Type elemType);
}];

let hasCustomAssemblyFormat = 1;
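A hypothetical caller-side sketch of the new API (the helper below is illustrative, not part of this patch): a pass can query the module once and branch on the reported execution size.

#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"

namespace ttgi = mlir::triton::gpu::intel;

// Returns the DPAS execution size implied by the module's min_sg_size
// attribute, or 0 when the module does not carry that attribute.
static unsigned dpasExecutionSize(mlir::ModuleOp mod) {
  ttgi::DpasEncodingAttr::DPASCapability cap =
      ttgi::DpasEncodingAttr::getDPASCapability(mod);
  if (cap.isPVC() || cap.isFalconShore())
    return 16; // 16-wide systolic array
  if (cap.isATSM())
    return 8; // 8-wide systolic array
  return cap.executionSize; // 0: no min_sg_size attribute present
}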
45 changes: 31 additions & 14 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -357,7 +357,7 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
return elemsPerThread;
};

SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
size_t rank = getWarpsPerCTA().size();
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);
@@ -381,6 +381,30 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
"be smaller than the threads required per row.");
}

DpasEncodingAttr::DPASCapability
DpasEncodingAttr::getDPASCapability(ModuleOp mod) {
assert(mod && "expected a valid module");

if (auto minSGSizeAttr = mod->getAttrOfType<IntegerAttr>(
triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) {
unsigned minSGSize = minSGSizeAttr.getInt();
assert((minSGSize == 8 || minSGSize == 16) && "unsupported minSGSize");
return DPASCapability(minSGSize);
}

return DPASCapability();
}

unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");

unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.

return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
}
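To make the arithmetic concrete, a small self-contained check (the free function below is hypothetical, mirroring getOpsPerChannel under the opsChanBitWidths == 32 definition above): bf16 gives 2 ops per channel and i8 gives 4, matching the opsPerChan values of the two DPAS layouts in the test above, while fp8 is widened to 16 bits first and also gives 2.

#include <cassert>

// Standalone mirror of DpasEncodingAttr::getOpsPerChannel:
// opsPerChannel = opsChanBitWidths / element bit width, with fp8 widened to 16.
static unsigned opsPerChannel(unsigned elemBitWidth, bool isFp8) {
  constexpr unsigned opsChanBitWidths = 32;
  if (isFp8)
    elemBitWidth *= 2; // fp8 operands are upcast to fp16 before the DPAS
  return opsChanBitWidths / elemBitWidth;
}

int main() {
  assert(opsPerChannel(16, /*isFp8=*/false) == 2); // bf16 / f16
  assert(opsPerChannel(8, /*isFp8=*/false) == 4);  // i8
  assert(opsPerChannel(8, /*isFp8=*/true) == 2);   // e5m2 / e4m3fn
  return 0;
}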

LogicalResult DpasEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
unsigned repeatCount, unsigned systolicDepth, unsigned executionSize,
@@ -469,18 +493,14 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
llvm::ArrayRef<unsigned> rC = shapeC;
auto warpsPerCTA = getWarpsPerCTA();
auto repCluster = getRepCluster();
printer << "<{"
<< "repeatCount = " << getRepeatCount() << ", "
printer << "<{" << "repeatCount = " << getRepeatCount() << ", "
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], "
<< "A = [" << rA << "], "
<< "B = [" << rB << "], "
<< "C = [" << rC << "]"
<< "}>";
<< "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
<< "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
}

std::optional<LinearLayout>
@@ -553,13 +573,10 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto threadsPerWarp = getThreadsPerWarp();
auto sizePerThread = getSizePerThread();
printer << "<{"
<< "sizePerThread = [" << llvm::ArrayRef<unsigned>(sizePerThread)
<< "]"
printer << "<{" << "sizePerThread = ["
<< llvm::ArrayRef<unsigned>(sizePerThread) << "]"
<< ", threadsPerWarp = [" << llvm::ArrayRef<unsigned>(threadsPerWarp)
<< "]"
<< ", order = [" << getOrder() << "]"
<< "}>";
<< "]" << ", order = [" << getOrder() << "]" << "}>";
}

//===----------------------------------------------------------------------===//