@@ -4,6 +4,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -50,15 +51,21 @@ LogicalResult UpcastMXFPOp::verify() {
     return success();
   }
 
+  // TODO: Temporarily disable this check to allow the blocked encoding.
+  // Re-enable once the dot operand encoding lowering for UpcastMXFPOp exists.
   auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  if (!dotEncoding) {
+  if (mlir::triton::tools::getBoolEnv(
+          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
+      !dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
     return emitOpError(
         "Expected a BlockOperandEncoding or LinearOperandEncoding "
         "for scales");
   }
+  if (!dotEncoding)
+    return success();
 
   if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
     // Necessary to keep all of the scales of a given block of values in the
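
To see what the new environment gate does in isolation, here is a minimal, self-contained sketch. `getBoolEnvSketch` is a hypothetical stand-in for `mlir::triton::tools::getBoolEnv` (declared in the newly included GetEnv.hpp); the exact accepted value spellings are an assumption, not taken from the patch.

#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical stand-in for mlir::triton::tools::getBoolEnv from
// triton/Tools/Sys/GetEnv.hpp; the accepted spellings are an assumption.
static bool getBoolEnvSketch(const char *name) {
  const char *v = std::getenv(name);
  return v != nullptr &&
         (std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0);
}

int main() {
  // Pretend the operand carries a blocked (not dot operand) encoding.
  const bool hasDotOperandEncoding = false;
  if (getBoolEnvSketch("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
      !hasDotOperandEncoding) {
    // Strict mode: non-dot-operand layouts are rejected, as before the patch.
    std::cout << "error: Expected a DotOperandEncodingAttr for values\n";
    return 1;
  }
  // Default mode: the dot-operand-specific checks are skipped.
  std::cout << "blocked encoding accepted\n";
  return 0;
}

With the variable unset, a null `dotEncoding` also short-circuits the remainder of verify() through the new early `return success()`.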
@@ -114,34 +121,45 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     newShape.back() *= 2;
     retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
   } else {
-    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-
-    const int opIdx = oldEncoding.getOpIdx();
-    const bool hasBatch = xShape.size() == 3;
-    const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-    newShape[kIdx] *= 2;
     Type elemType = FloatType::getBF16(ctx);
-
-    // Note: For Intel the dot operands layout's kWidth parameter must match
-    // the parent's DPAS layout opsPerChannel so we need to materialize a new
-    // DPAS layout.
-    Attribute newVEncoding;
-    if (auto dpasEncoding =
-            dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
-      auto newDpasEncoding = intel::DpasEncodingAttr::get(
-          ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
-          dpasEncoding.getExecutionSize(),
-          intel::DpasEncodingAttr::getOpsPerChannel(elemType),
-          dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-          dpasEncoding.getSubGroupSize());
-      newVEncoding = DotOperandEncodingAttr::get(
-          ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
-    } else {
-      // Figure out the K dimension for the input A/B, given that the return
-      // type is upcasted A/B type so we need to update the proper dim size.
-      newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
-                                                 oldEncoding.getParent(),
-                                                 oldEncoding.getKWidth() * 2);
+    Attribute newVEncoding = nullptr;
+    if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
+      const int opIdx = oldEncoding.getOpIdx();
+      const bool hasBatch = xShape.size() == 3;
+      const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+      newShape[kIdx] *= 2;
+
+      // Note: For Intel the dot operand layout's kWidth parameter must match
+      // the parent DPAS layout's opsPerChannel, so we need to materialize a
+      // new DPAS layout.
+      if (auto dpasEncoding =
+              dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
+        auto newDpasEncoding = intel::DpasEncodingAttr::get(
+            ctx, dpasEncoding.getRepeatCount(),
+            dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
+            intel::DpasEncodingAttr::getOpsPerChannel(elemType),
+            dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+            dpasEncoding.getSubGroupSize());
+        newVEncoding = DotOperandEncodingAttr::get(
+            ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
+      } else {
+        // Figure out the K dimension for the input A/B: the return type is
+        // the upcasted A/B type, so we need to update the proper dim size.
+        newVEncoding = DotOperandEncodingAttr::get(
+            ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+            oldEncoding.getKWidth() * 2);
+      }
+    } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
+      // TODO: Temporary code; remove once upcast_mxfp supports dot encoding.
+      assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
+      SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
+      int opIdx = sizePerThread.back() == 1 ? 1 : 0;
+      sizePerThread[!opIdx] *= 2;
+      newShape[!opIdx] *= 2;
+      newVEncoding = BlockedEncodingAttr::get(
+          ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
+          oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
+          oldEncoding.getCTALayout());
     }
     retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
   }
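
The DPAS branch re-materializes the parent layout because the upcast changes the element type, which changes opsPerChannel, which in turn must equal the dot operand's kWidth (per the Note comment above). The sketch below illustrates why, assuming opsPerChannel is derived from a 32-bit DPAS channel; that formula is an assumption for illustration, and the real rule lives in intel::DpasEncodingAttr::getOpsPerChannel.

#include <cassert>
#include <iostream>

// Assumed rule, for illustration only: a DPAS channel is 32 bits wide, so
// opsPerChannel = 32 / element bit-width (the real computation lives in
// intel::DpasEncodingAttr::getOpsPerChannel).
unsigned opsPerChannelSketch(unsigned elemBitWidth) {
  assert(32 % elemBitWidth == 0);
  return 32 / elemBitWidth;
}

int main() {
  // The fp4 input arrives packed two-per-i8, but the upcast result is bf16,
  // so the layout's opsPerChannel (and hence kWidth) must be recomputed.
  std::cout << "i8 operand:  opsPerChannel = " << opsPerChannelSketch(8)
            << "\n"
            << "bf16 result: opsPerChannel = " << opsPerChannelSketch(16)
            << "\n";
  return 0;
}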
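
The temporary blocked-encoding branch infers which dot operand the layout feeds from sizePerThread: a thread owning a single element along the last dimension marks operand B (K first), otherwise operand A (K last), and `!opIdx` then selects the K dimension to double. A minimal sketch of that heuristic for the rank-2 case follows; the concrete shape and sizePerThread values are invented for illustration.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Invented rank-2 example: for operand A (opIdx 0) the K dimension is the
  // last one, so threads own more than one element there; for operand B
  // (opIdx 1) K comes first and sizePerThread.back() == 1.
  std::array<unsigned, 2> sizePerThread = {1, 4}; // looks like an A operand
  std::array<std::int64_t, 2> shape = {128, 32};  // fp4 packed two-per-i8

  const int opIdx = sizePerThread.back() == 1 ? 1 : 0;
  // !opIdx picks the K dimension: dim 1 for A, dim 0 for B. Upcasting fp4 to
  // bf16 unpacks two values per byte, doubling K.
  sizePerThread[!opIdx] *= 2;
  shape[!opIdx] *= 2;

  std::cout << "opIdx = " << opIdx << ", upcast shape = " << shape[0] << "x"
            << shape[1] << "\n"; // prints: opIdx = 0, upcast shape = 128x64
  return 0;
}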