
Commit 464d1f1

[AMD] Revert the AMD path of #5475 (#5911)
AMD was seeing big regressions after that PR. There are a couple of things that may need adjustment, so reverting the AMD changes for now until the AMD folks have the bandwidth to investigate.
1 parent 79a8a3b commit 464d1f1

8 files changed (+688 −227 lines)


third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 28 additions & 0 deletions
```diff
@@ -266,6 +266,34 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
   }];
 }
 
+def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
+  let summary = "Convert an mxfp tensor to bf16/fp16";
+
+  let hasVerifier = 1;
+
+  let description = [{
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+  let arguments = (
+    ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_ScaleDotElemTypeAttr:$fp_type,
+    BoolAttr:$fastMath
+  );
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    static RankedTensorType deduceOutputType(
+        TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
+  }];
+}
+
 def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
   SameLoadStoreOperandsEncoding,
   MemoryEffects<[MemWrite<GlobalMemory>]>,
```
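Given the `assemblyFormat` above, the op's textual form looks roughly like the following. This is a hypothetical instance, not IR from the commit: the value names and shapes are made up, and the layout encodings that the verifier normally expects on the tensor types are elided for brevity.

```mlir
// Hypothetical amdgpu.upcast_mxfp instance for a packed-fp4 (e2m1) operand:
// each i8 of %src holds two fp4 values, so the 64-wide K dimension of the
// source expands to 128 bf16 elements in the result.
%r = amdgpu.upcast_mxfp %src, %scale fp_type = e2m1 {fastMath = false}
    : tensor<32x64xi8>, tensor<32x4xi8> -> tensor<32x128xbf16>
```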

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 108 additions & 0 deletions
```diff
@@ -133,4 +133,112 @@ LogicalResult ExtractSliceOp::verify() {
 
   return success();
 }
+
+LogicalResult UpcastMXFPOp::verify() {
+  auto fpType = getFpType();
+
+  auto xTy = getSrc().getType();
+  auto scaleTy = getScale().getType();
+  Builder b(getContext());
+  if (xTy.getElementType() != b.getBF16Type() &&
+      xTy.getElementType() != b.getF16Type() &&
+      xTy.getElementType() != b.getI8Type()) {
+    return emitOpError(
+        "element type of the first operand must be bf16/fp16 or i8");
+  }
+
+  if (scaleTy.getElementType() != b.getI8Type()) {
+    return emitOpError("element type of the second operand must be uint8");
+  }
+
+  auto xShape = xTy.getShape();
+  auto scaleShape = scaleTy.getShape();
+
+  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
+    return emitOpError(
+        "operands must have the same number of dimensions, at least 2");
+  }
+
+  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
+        fpType == ScaleDotElemType::E5M2)) {
+    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  auto layoutScale = scaleTy.getEncoding();
+  if (bool(layoutX) != bool(layoutScale)) {
+    return emitOpError(
+        "Expected either both or neither operands to have an encoding");
+  }
+  // Nothing to check if no encoding. This is used to infer the return type in
+  // AccelerateMatmul.cpp.
+  if (!layoutX) {
+    return success();
+  }
+
+  auto dotEncoding = dyn_cast<gpu::DotOperandEncodingAttr>(layoutX);
+  if (!dotEncoding) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
+  if (!isa<gpu::BlockedEncodingAttr, gpu::LinearEncodingAttr>(layoutScale)) {
+    return emitOpError(
+        "Expected a BlockedEncodingAttr or LinearEncodingAttr for scales");
+  }
+
+  // TODO: change to support fp8 scale types.
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For the A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be (32 / elemsPacked) "
+                       "times larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
+
+  return success();
+}
+
+RankedTensorType
+UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
+                               ScaleDotElemType inputElemType,
+                               Type outputElemType) {
+  MLIRContext *ctx = inputTensor.getContext();
+  auto xTy = inputTensor.getType();
+  if (inputElemType != ScaleDotElemType::E2M1)
+    return xTy;
+
+  auto xShape = xTy.getShape();
+  auto newShape = llvm::to_vector(xShape);
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    newShape.back() *= 2;
+    return RankedTensorType::get(newShape, outputElemType);
+  }
+
+  auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+  auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                  oldEncoding.getParent(),
+                                                  oldEncoding.getKWidth() * 2);
+  // Figure out the K dimension for the input A/B, given that the return
+  // type is the upcasted A/B type, so we need to update the proper dim size.
+  const int opIdx = oldEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+  newShape[kIdx] *= 2;
+  return RankedTensorType::get(newShape, outputElemType, newVEncoding);
+}
+
 } // namespace mlir::triton::amdgpu
```
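To make the verifier's shape rules concrete: one scale byte covers a block of 32 values, so with e2m1 packing two fp4 values per i8 the src K dimension must be 16 × the scale's last dimension (as in the e2m1 example earlier), while for e4m3/e5m2 it must be 32 ×. For e2m1, `deduceOutputType` doubles the K dimension and the encoding's `kWidth`; for the other formats the input type passes through unchanged. A hypothetical instance under those rules (names and shapes are illustrative; encodings are elided, which skips the K-dimension check, but the shapes are chosen to satisfy it anyway):

```mlir
// e5m2: one value per element and no shape change on upcast, so
// K(src) = 32 * K(scale): 128 = 32 * 4. The result type equals the src
// type, since deduceOutputType returns the input type for non-e2m1.
%b = amdgpu.upcast_mxfp %src2, %scale2 fp_type = e5m2 {fastMath = false}
    : tensor<32x128xbf16>, tensor<32x4xi8> -> tensor<32x128xbf16>
```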

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ add_triton_library(TritonAMDGPUToLLVM
   OptimizeLDSUtility.cpp
   SPMDOpToLLVM.cpp
   SchedInstructions.cpp
-  Fp4ToFpOpToLLVM.cpp
+  UpcastMXFPToLLVM.cpp
 
   DEPENDS
   TritonAMDGPUConversionPassIncGen
```

third_party/amd/lib/TritonAMDGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 0 additions & 212 deletions
This file was deleted.

third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 3 deletions
```diff
@@ -38,9 +38,10 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         PatternBenefit benefit);
 
-void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
-                                   RewritePatternSet &patterns,
-                                   PatternBenefit benefit);
+void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      const TargetInfo &targetInfo,
+                                      PatternBenefit benefit);
 
 } // namespace mlir::triton::AMD
```

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -201,8 +201,8 @@ struct ConvertTritonAMDGPUToLLVM
 
     mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter,
                                                           patterns, AMDBenefit);
-    mlir::triton::AMD::populateFp4ToFpToLLVMPatterns(typeConverter, patterns,
-                                                     AMDBenefit);
+    mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
+                                                        targetInfo, AMDBenefit);
 
     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns
```
