Skip to content

Commit 5bab216

Browse files
Merge commit '3bac3be56609c8f7286a244d4622ea72a2fc4402'
2 parents c74534d + 3bac3be commit 5bab216

File tree

55 files changed

+969
-665
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

55 files changed

+969
-665
lines changed

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6868
mlir::triton::registerTritonGPUGlobalScratchAllocationPass();
6969
mlir::triton::registerConvertTritonGPUToLLVMPass();
7070
mlir::triton::registerConvertNVGPUToLLVMPass();
71-
mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();
7271
mlir::registerLLVMDIScope();
7372
mlir::triton::gpu::intel::registerTritonAnnotateModulePass();
7473
mlir::triton::gpu::intel::registerTritonIntelGPUPasses();

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def TT_ScaleDotElemTypeAttr : I32EnumAttr<
128128
I32EnumAttrCase<"E2M3", 2, "e2m3">,
129129
I32EnumAttrCase<"E3M2", 3, "e3m2">,
130130
I32EnumAttrCase<"E2M1", 4, "e2m1">,
131-
I32EnumAttrCase<"BF16", 5, "bf16">
132-
131+
I32EnumAttrCase<"BF16", 5, "bf16">,
132+
I32EnumAttrCase<"FP16", 6, "fp16">
133133
]>{
134134
let cppNamespace = "::mlir::triton";
135135
}

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,27 +167,6 @@ template <typename VecT> bool isConsecutive(const VecT &vec) {
167167
return isConsecutive(ArrayRef(vec));
168168
}
169169

170-
// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but
171-
// it's missing min/max_element until
172-
// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton.
173-
// TODO(jlebar): Remove this once we have the LLVM helpers.
174-
template <typename R> auto min_element(R &&Range) {
175-
return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range));
176-
}
177-
template <typename R, typename Compare>
178-
auto min_element(R &&Range, Compare &&C) {
179-
return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range),
180-
std::forward<Compare>(C));
181-
}
182-
template <typename R> auto max_element(R &&Range) {
183-
return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range));
184-
}
185-
template <typename R, typename T, typename Compare>
186-
auto max_element(R &&Range, Compare &&C) {
187-
return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range),
188-
std::forward<Compare>(C));
189-
}
190-
191170
} // namespace triton
192171
} // namespace mlir
193172

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
283283
}];
284284
}
285285

286-
def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
287-
let summary = "Convert an mxfp tensor to bf16";
286+
def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
287+
let summary = "Convert an mxfp tensor to bf16/fp16";
288288

289289
let hasVerifier = 1;
290290

@@ -301,6 +301,11 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
301301
let assemblyFormat = [{
302302
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
303303
}];
304+
305+
let extraClassDeclaration = [{
306+
static RankedTensorType deduceOutputType(
307+
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
308+
}];
304309
}
305310

306311
// Allocate global memory

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3333
"TRITON_HIP_USE_BLOCK_PINGPONG",
3434
"TRITON_LLVM_DEBUG_ONLY",
3535
"TRITON_ENABLE_ASAN",
36-
"TRITON_OVERRIDE_NV_CAPABILITY",
36+
"TRITON_OVERRIDE_ARCH",
3737
"USE_IR_LOC",
3838
"NVPTX_ENABLE_DUMP",
3939
"TRITON_INTEL_ADVANCED_PATH",

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -356,10 +356,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
356356
auto srcTy = op.getSrc().getType();
357357
auto dstTy = op.getType();
358358

359-
// TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
360-
// -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
361-
// completed before we can remove the layoutIsOK check:
362-
// 1. Support for AMD's WMMA dot operand
363359
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
364360
if (isa<MmaEncodingTrait>(layout)) {
365361
return !useLegacyMMAConversion;
@@ -368,15 +364,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
368364
if (isa<MmaEncodingTrait>(dotOperand.getParent())) {
369365
return !useLegacyMMAConversion;
370366
}
371-
return false;
372-
}
373-
if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
374-
return true;
375367
}
376368
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
377369
return layoutIsOK(slice.getParent());
378370
}
379-
return false;
371+
return true;
380372
};
381373
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
382374
return failure();

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,15 @@ LogicalResult UpcastMXFPOp::verify() {
303303

304304
auto xTy = getSrc().getType();
305305
auto scaleTy = getScale().getType();
306-
307-
if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
308-
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
309-
return emitOpError("element type of the first operand must be bf16 or i8");
306+
Builder b(getContext());
307+
if (xTy.getElementType() != b.getBF16Type() &&
308+
xTy.getElementType() != b.getF16Type() &&
309+
xTy.getElementType() != b.getI8Type()) {
310+
return emitOpError(
311+
"element type of the first operand must be bf16/fp16 or i8");
310312
}
311313

312-
if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
314+
if (scaleTy.getElementType() != b.getI8Type()) {
313315
return emitOpError("element type of the second operand must be uint8");
314316
}
315317

@@ -383,66 +385,55 @@ LogicalResult UpcastMXFPOp::verify() {
383385
return success();
384386
}
385387

386-
LogicalResult UpcastMXFPOp::inferReturnTypes(
387-
MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
388-
DictionaryAttr attributes, OpaqueProperties opaqueProperties,
389-
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
390-
auto xTy = cast<RankedTensorType>(operands[0].getType());
391-
auto properties = opaqueProperties.as<const Properties *>();
392-
auto typeEncoded = properties->fp_type.getValue();
393-
auto xShape = xTy.getShape();
388+
RankedTensorType
389+
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
390+
ScaleDotElemType inputElemType,
391+
Type outputElemType) {
392+
MLIRContext *ctx = inputTensor.getContext();
393+
auto xTy = inputTensor.getType();
394+
if (inputElemType != ScaleDotElemType::E2M1)
395+
return xTy;
394396

397+
auto xShape = xTy.getShape();
398+
auto newShape = llvm::to_vector(xShape);
395399
auto encoding = xTy.getEncoding();
396-
397-
if (typeEncoded == ScaleDotElemType::E2M1) {
398-
RankedTensorType retTy;
399-
400-
auto newShape = SmallVector<int64_t>(xShape);
401-
if (!encoding) {
402-
newShape.back() *= 2;
403-
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
404-
} else {
405-
Type elemType = FloatType::getBF16(ctx);
406-
Attribute newVEncoding = nullptr;
407-
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
408-
const int opIdx = oldEncoding.getOpIdx();
409-
const bool hasBatch = xShape.size() == 3;
410-
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
411-
newShape[kIdx] *= 2;
412-
413-
// Note: For Intel the dot operands layout's kWidth parameter must match
414-
// the parent's DPAS layout opsPerChannel so we need to materialize a
415-
// new DPAS layout.
416-
if (auto dpasEncoding =
417-
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
418-
unsigned opsPerChannel =
419-
intel::DpasEncodingAttr::getOpsPerChannel(elemType);
420-
// e2m1 is packed 2 elements per int8, we must handle continuous 2
421-
// elements when upcasting to bf16
422-
if (xTy.getElementType() == IntegerType::get(ctx, 8))
423-
opsPerChannel *= 2;
424-
auto newDpasEncoding = intel::DpasEncodingAttr::get(
425-
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
426-
dpasEncoding.getExecutionSize(), opsPerChannel,
427-
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
428-
product<unsigned>(dpasEncoding.getThreadsPerWarp()));
429-
newVEncoding = DotOperandEncodingAttr::get(
430-
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
431-
} else {
432-
// Figure out the K dimension for the input A/B, given that the return
433-
// type is upcasted A/B type so we need to update the proper dim size.
434-
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
435-
oldEncoding.getParent(),
436-
oldEncoding.getKWidth() * 2);
437-
}
438-
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
439-
}
440-
inferredReturnTypes.push_back(retTy);
400+
if (!encoding) {
401+
newShape.back() *= 2;
402+
return RankedTensorType::get(xShape, outputElemType);
403+
}
404+
405+
Attribute newVEncoding = nullptr;
406+
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
407+
const int opIdx = oldEncoding.getOpIdx();
408+
// Note: For Intel the dot operands layout's kWidth parameter must match
409+
// the parent's DPAS layout opsPerChannel so we need to materialize a
410+
// new DPAS layout.
411+
if (auto dpasEncoding =
412+
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
413+
unsigned opsPerChannel =
414+
intel::DpasEncodingAttr::getOpsPerChannel(outputElemType);
415+
// e2m1 is packed 2 elements per int8, we must handle continuous 2
416+
// elements when upcasting to bf16
417+
if (xTy.getElementType() == IntegerType::get(ctx, 8))
418+
opsPerChannel *= 2;
419+
auto newDpasEncoding = intel::DpasEncodingAttr::get(
420+
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
421+
dpasEncoding.getExecutionSize(), opsPerChannel,
422+
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
423+
product<unsigned>(dpasEncoding.getThreadsPerWarp()));
424+
newVEncoding = DotOperandEncodingAttr::get(
425+
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
441426
} else {
442-
inferredReturnTypes.push_back(xTy);
427+
// Figure out the K dimension for the input A/B, given that the return
428+
// type is upcasted A/B type so we need to update the proper dim size.
429+
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
430+
oldEncoding.getParent(),
431+
oldEncoding.getKWidth() * 2);
443432
}
444-
445-
return success();
433+
const bool hasBatch = xShape.size() == 3;
434+
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
435+
newShape[kIdx] *= 2;
436+
return RankedTensorType::get(newShape, outputElemType, newVEncoding);
446437
}
447438

448439
OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,10 @@ class DecomposeScaledBlocked
573573
maybeWithEncoding(scale.getType(), scaleEncoding);
574574
scale = rewriter.create<ConvertLayoutOp>(scale.getLoc(),
575575
newScaleDotElemType, scale);
576-
ret = rewriter.create<triton::gpu::UpcastMXFPOp>(v.getLoc(), ret, scale,
577-
type);
576+
auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType(
577+
ret, type, Builder(v.getContext()).getBF16Type());
578+
ret = rewriter.create<triton::gpu::UpcastMXFPOp>(v.getLoc(), retTy, ret,
579+
scale, type);
578580
}
579581
return ret;
580582
}

lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
115115
}
116116

117117
auto stages = llvm::make_second_range(opToStage);
118-
int maxStage = *std::max_element(stages.begin(), stages.end());
118+
int maxStage = *llvm::max_element(stages);
119119
CoarseSchedule schedule(maxStage + 1);
120120
SmallVector<CoarseSchedule::Cluster> clusters(maxStage + 1);
121121
for (int i = 0; i <= maxStage; i++) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
257257

258258
// Calculate the stage distance between applicable loads.
259259
auto vals = llvm::make_second_range(loadOpToIndLevel);
260-
int maxIndirectionLevel =
261-
vals.empty() ? 0 : *std::max_element(vals.begin(), vals.end());
260+
int maxIndirectionLevel = vals.empty() ? 0 : *llvm::max_element(vals);
262261
unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1);
263262

264263
for (auto [loadOp, dist] : loadOpToIndLevel) {

0 commit comments

Comments
 (0)