Commit f7aaf04

Add UpcastMXFP op to TritonIntelGPU Dialect to reduce common file changes (#3145)
1 parent 7c63a47 commit f7aaf04

6 files changed: +162 -34 lines changed


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 5 additions & 26 deletions
@@ -392,34 +392,13 @@ UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
     return RankedTensorType::get(xShape, outputElemType);
   }
 
-  Attribute newVEncoding = nullptr;
   auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+  auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                  oldEncoding.getParent(),
+                                                  oldEncoding.getKWidth() * 2);
+  // Figure out the K dimension for the input A/B, given that the return
+  // type is upcasted A/B type so we need to update the proper dim size.
   const int opIdx = oldEncoding.getOpIdx();
-  // Note: For Intel the dot operands layout's kWidth parameter must match
-  // the parent's DPAS layout opsPerChannel so we need to materialize a
-  // new DPAS layout.
-  if (auto dpasEncoding =
-          dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
-    unsigned opsPerChannel =
-        intel::DpasEncodingAttr::getOpsPerChannel(outputElemType);
-    // e2m1 is packed 2 elements per int8, we must handle continuous 2
-    // elements when upcasting to bf16
-    if (xTy.getElementType() == IntegerType::get(ctx, 8))
-      opsPerChannel *= 2;
-    auto newDpasEncoding = intel::DpasEncodingAttr::get(
-        ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
-        dpasEncoding.getExecutionSize(), opsPerChannel,
-        dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-        product<unsigned>(dpasEncoding.getThreadsPerWarp()));
-    newVEncoding = DotOperandEncodingAttr::get(
-        ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
-  } else {
-    // Figure out the K dimension for the input A/B, given that the return
-    // type is upcasted A/B type so we need to update the proper dim size.
-    newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
-                                               oldEncoding.getParent(),
-                                               oldEncoding.getKWidth() * 2);
-  }
   const bool hasBatch = xShape.size() == 3;
   const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
   newShape[kIdx] *= 2;

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 3 additions & 3 deletions
@@ -222,7 +222,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
 // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
 // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED3]]>
-// CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = false} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+// CHECK: [[UPCAST:%.*]] = triton_intel_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = false} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
 // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[B:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
@@ -239,7 +239,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
 // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
 // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED3]]>
-// CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = true} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+// CHECK: [[UPCAST:%.*]] = triton_intel_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = true} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
 // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xf8E4M3FN, [[BLOCKED2]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG2]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
@@ -285,7 +285,7 @@ module attributes {ttg.target = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps"
 // CHECK: [[C:%.*]] = ttg.convert_layout [[ARG5]] : tensor<32x128xf32, [[BLOCKED4]]> -> tensor<32x128xf32, [[DPAS]]>
 // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[TRANS_B]] : tensor<32x32xi8, [[BLOCKED4]]> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<32x2xi8, [[BLOCKED2]]> -> tensor<32x2xi8, [[BLOCKED6]]>
-// CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<32x2xi8, [[BLOCKED6]]> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+// CHECK: [[UPCAST:%.*]] = triton_intel_gpu.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<32x2xi8, [[BLOCKED6]]> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
 // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[TRANS_A]] : tensor<64x128xf8E4M3FN, [[BLOCKED5]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
 // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG0]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td

Lines changed: 29 additions & 0 deletions
@@ -202,4 +202,33 @@ def TTIG_SubGroupTransposeOp
   let hasVerifier = 1;
 }
 
+// The same as ttg.upcast_mxfp, but we want Dot Layout from Dpas layout for input tensor
+def TTIG_UpcastMXFPOp : TTIG_Op<"upcast_mxfp", [Pure]> {
+  let summary = "Convert an mxfp tensor to bf16/fp16";
+
+  let hasVerifier = 1;
+
+  let description = [{
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+  let arguments = (
+    ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_ScaleDotElemTypeAttr:$fp_type,
+    BoolAttr:$fastMath
+  );
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    static RankedTensorType deduceOutputType(
+        TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
+  }];
+}
+
 #endif
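
For reference, a minimal sketch of the textual form this assemblyFormat produces (illustrative only; the shapes mirror the accelerate-matmul-pvc.mlir checks above, and #dot_a, #blocked, and #dot_a_upcast are placeholder names for concrete layout attributes, not aliases defined in this commit):

%r = triton_intel_gpu.upcast_mxfp %src, %scale fp_type = e2m1 {fastMath = false}
    : tensor<128x32xi8, #dot_a>, tensor<128x2xi8, #blocked>
      -> tensor<128x64xbf16, #dot_a_upcast>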

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp

Lines changed: 120 additions & 0 deletions
@@ -209,4 +209,124 @@ LogicalResult SubGroupTransposeOp::verify() {
   return success();
 }
 
+LogicalResult UpcastMXFPOp::verify() {
+  auto fpType = getFpType();
+
+  auto xTy = getSrc().getType();
+  auto scaleTy = getScale().getType();
+  Builder b(getContext());
+  if (xTy.getElementType() != b.getBF16Type() &&
+      xTy.getElementType() != b.getF16Type() &&
+      xTy.getElementType() != b.getI8Type()) {
+    return emitOpError(
+        "element type of the first operand must be bf16/fp16 or i8");
+  }
+
+  if (scaleTy.getElementType() != b.getI8Type()) {
+    return emitOpError("element type of the second operand must be uint8");
+  }
+
+  auto xShape = xTy.getShape();
+  auto scaleShape = scaleTy.getShape();
+
+  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
+    return emitOpError(
+        "operands must have the same number of dimensions, at least 2");
+  }
+
+  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
+        fpType == ScaleDotElemType::E5M2)) {
+    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  auto layoutScale = scaleTy.getEncoding();
+  if (bool(layoutX) != bool(layoutScale)) {
+    return emitOpError(
+        "Expected either both or neither operands to have an encoding");
+  }
+  // Nothing to check if no encoding. This is used to infer the return type in
+  // AccelerateMatmul.cpp
+  if (!layoutX) {
+    return success();
+  }
+
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
+  if (!dotEncoding) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
+  if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
+    return emitOpError(
+        "Expected a BlockOperandEncoding or LinearOperandEncoding "
+        "for scales");
+  }
+
+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
+
+  return success();
+}
+
+RankedTensorType
+UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
+                               ScaleDotElemType inputElemType,
+                               Type outputElemType) {
+  MLIRContext *ctx = inputTensor.getContext();
+  auto xTy = inputTensor.getType();
+  if (inputElemType != ScaleDotElemType::E2M1)
+    return xTy;
+
+  auto xShape = xTy.getShape();
+  auto newShape = llvm::to_vector(xShape);
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    newShape.back() *= 2;
+    return RankedTensorType::get(xShape, outputElemType);
+  }
+
+  auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+  const int opIdx = oldEncoding.getOpIdx();
+  // Note: For Intel the dot operands layout's kWidth parameter must match
+  // the parent's DPAS layout opsPerChannel so we need to materialize a
+  // new DPAS layout.
+  auto dpasEncoding = cast<intel::DpasEncodingAttr>(oldEncoding.getParent());
+  unsigned opsPerChannel =
+      intel::DpasEncodingAttr::getOpsPerChannel(outputElemType);
+  // e2m1 is packed 2 elements per int8, we must handle continuous 2
+  // elements when upcasting to bf16
+  if (xTy.getElementType() == IntegerType::get(ctx, 8))
+    opsPerChannel *= 2;
+  auto newDpasEncoding = intel::DpasEncodingAttr::get(
+      ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
+      dpasEncoding.getExecutionSize(), opsPerChannel,
+      dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+      product<unsigned>(dpasEncoding.getThreadsPerWarp()));
+  Attribute newVEncoding = DotOperandEncodingAttr::get(
+      ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
+
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+  newShape[kIdx] *= 2;
+  return RankedTensorType::get(newShape, outputElemType, newVEncoding);
+}
+
 } // namespace mlir::triton::gpu::intel
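
As a worked check of the verifier and deduceOutputType arithmetic, using the e2m1 shapes from the test above (a hand-computed example, not part of the commit): for fp_type = e2m1 the source packs two fp4 values per i8, so elemsPacked = 2 and the source K dimension must equal (32 / 2) * scaleK = 16 * scaleK. With src = tensor<128x32xi8> (opIdx = 0, so kIdx = 1) and scale = tensor<128x2xi8>, the check 32 == 16 * 2 passes; deduceOutputType then doubles K and doubles the DPAS opsPerChannel for the i8-packed input, which yields tensor<128x64xbf16> with the kWidth = 4 dot-operand layout seen in the CHECK lines.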

third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 using namespace mlir;
 using namespace mlir::triton;
-using namespace mlir::triton::gpu;
+using namespace mlir::triton::gpu::intel;
 
 namespace {
 
@@ -80,7 +80,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     // kWidth here is the contiguous number of elements each thread access.
     unsigned kWidth = dpasEnc.getOpsPerChannel() / 2;
     unsigned numMxfp =
-        TritonGPUDialect::TritonGPUDialect::getThreadsPerWarp(mod) / instShapeM;
+        triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod) / instShapeM;
     unsigned mxfpSize = repSize * subTileSize * kWidth;
     constexpr unsigned numScales = 16;

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 3 additions & 3 deletions
@@ -423,10 +423,10 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     if (!scale)
       return v;
 
-    auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType(
+    auto retTy = triton::gpu::intel::UpcastMXFPOp::deduceOutputType(
         v, elemType, Builder(v.getContext()).getBF16Type());
-    return rewriter.create<ttg::UpcastMXFPOp>(v.getLoc(), retTy, v, scale,
-                                               elemType, fastMath);
+    return rewriter.create<ttgi::UpcastMXFPOp>(v.getLoc(), retTy, v, scale,
+                                               elemType, fastMath);
   }
 };
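
Here ttgi is presumably the local namespace alias for mlir::triton::gpu::intel used in AccelerateMatmul.cpp (an assumption; the alias definition is outside this diff), so the decomposition now builds the Intel dialect's UpcastMXFPOp added above instead of the common ttg.upcast_mxfp.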
