|
5 | 5 | #include "triton/Dialect/TritonGPU/IR/Attributes.h" |
6 | 6 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
7 | 7 | #include "triton/Dialect/TritonGPU/Transforms/Utility.h" |
8 | | -#include "triton/Tools/Sys/GetEnv.hpp" |
9 | 8 |
|
10 | 9 | #define GET_OP_CLASSES |
11 | 10 | #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" |
@@ -339,21 +338,15 @@ LogicalResult UpcastMXFPOp::verify() { |
339 | 338 | return success(); |
340 | 339 | } |
341 | 340 |
|
342 | | - /// TODO: Temporarily disabled this check to allow for the blocked encoding. |
343 | | - /// Enable once we have the dot op encoding UpcastMXFPOp lowering. |
344 | 341 | auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX); |
345 | | - if (mlir::triton::tools::getBoolEnv( |
346 | | - "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") && |
347 | | - !dotEncoding) { |
| 342 | + if (!dotEncoding) { |
348 | 343 | return emitOpError("Expected a DotOperandEncodingAttr for values"); |
349 | 344 | } |
350 | 345 | if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) { |
351 | 346 | return emitOpError( |
352 | 347 | "Expected a BlockOperandEncoding or LinearOperandEncoding " |
353 | 348 | "for scales"); |
354 | 349 | } |
355 | | - if (!dotEncoding) |
356 | | - return success(); |
357 | 350 |
|
358 | 351 | if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) { |
359 | 352 | // Necessary to keep all of the scales of a given block of values in the |
@@ -411,43 +404,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( |
411 | 404 | } else { |
412 | 405 | Type elemType = FloatType::getBF16(ctx); |
413 | 406 | Attribute newVEncoding = nullptr; |
414 | | - if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) { |
415 | | - const int opIdx = oldEncoding.getOpIdx(); |
416 | | - const bool hasBatch = xShape.size() == 3; |
417 | | - const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; |
418 | | - newShape[kIdx] *= 2; |
419 | | - |
420 | | - // Note: For Intel the dot operands layout's kWidth parameter must match |
421 | | - // the parent's DPAS layout opsPerChannel so we need to materialize a |
422 | | - // new DPAS layout. |
423 | | - if (auto dpasEncoding = |
424 | | - dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) { |
425 | | - auto newDpasEncoding = intel::DpasEncodingAttr::get( |
426 | | - ctx, dpasEncoding.getRepeatCount(), |
427 | | - dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(), |
428 | | - intel::DpasEncodingAttr::getOpsPerChannel(elemType), |
429 | | - dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), |
430 | | - product<unsigned>(dpasEncoding.getThreadsPerWarp())); |
431 | | - newVEncoding = DotOperandEncodingAttr::get( |
432 | | - ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel()); |
433 | | - } else { |
434 | | - // Figure out the K dimension for the input A/B, given that the return |
435 | | - // type is upcasted A/B type so we need to update the proper dim size. |
436 | | - newVEncoding = DotOperandEncodingAttr::get( |
437 | | - ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), |
438 | | - oldEncoding.getKWidth() * 2); |
439 | | - } |
440 | | - } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) { |
441 | | - // TODO: Temporary code, remove once upcast_mxfp support dot encoding. |
442 | | - assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING")); |
443 | | - SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread(); |
444 | | - int opIdx = sizePerThread.back() == 1 ? 1 : 0; |
445 | | - sizePerThread[!opIdx] *= 2; |
446 | | - newShape[!opIdx] *= 2; |
447 | | - newVEncoding = BlockedEncodingAttr::get( |
448 | | - ctx, sizePerThread, oldEncoding.getThreadsPerWarp(), |
449 | | - oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(), |
450 | | - oldEncoding.getCTALayout()); |
| 407 | + auto oldEncoding = cast<DotOperandEncodingAttr>(encoding); |
| 408 | + const int opIdx = oldEncoding.getOpIdx(); |
| 409 | + const bool hasBatch = xShape.size() == 3; |
| 410 | + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; |
| 411 | + newShape[kIdx] *= 2; |
| 412 | + |
| 413 | + // Note: For Intel the dot operands layout's kWidth parameter must match |
| 414 | + // the parent's DPAS layout opsPerChannel so we need to materialize a |
| 415 | + // new DPAS layout. |
| 416 | + if (auto dpasEncoding = |
| 417 | + dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) { |
| 418 | + unsigned opsPerChannel = |
| 419 | + intel::DpasEncodingAttr::getOpsPerChannel(elemType); |
| 420 | + // e2m1 is packed 2 elements per int8, we must handle continuous 2 |
| 421 | + // elements when upcasting to bf16 |
| 422 | + if (xTy.getElementType() == IntegerType::get(ctx, 8)) |
| 423 | + opsPerChannel *= 2; |
| 424 | + auto newDpasEncoding = intel::DpasEncodingAttr::get( |
| 425 | + ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(), |
| 426 | + dpasEncoding.getExecutionSize(), opsPerChannel, |
| 427 | + dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), |
| 428 | + product<unsigned>(dpasEncoding.getThreadsPerWarp())); |
| 429 | + newVEncoding = DotOperandEncodingAttr::get( |
| 430 | + ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel()); |
| 431 | + } else { |
| 432 | + // Figure out the K dimension for the input A/B, given that the return |
| 433 | + // type is upcasted A/B type so we need to update the proper dim size. |
| 434 | + newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), |
| 435 | + oldEncoding.getParent(), |
| 436 | + oldEncoding.getKWidth() * 2); |
451 | 437 | } |
452 | 438 | retTy = RankedTensorType::get(newShape, elemType, newVEncoding); |
453 | 439 | } |
|
0 commit comments