
Commit 48e661d

Merge commit '8966e5c8d3397e9e3c4100fd1075c715ee6629b9'
2 parents: 8bb917f + 8966e5c

File tree

30 files changed: 772 additions, 46 deletions

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
     MLIRSCFToControlFlow
     MLIRIndexToLLVM
     MLIRGPUToROCDLTransforms
+    MLIRUBToLLVM
 
     # LLVM
     LLVMPasses

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 4 additions & 2 deletions
@@ -125,16 +125,18 @@ using namespace mlir::triton;
 #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
 
 // Constants
+#define int_val(bitwidth, val) \
+  LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val)
 #define i1_val(val) LLVM::createConstantI1(loc, rewriter, val)
 #define true_val() i1_val(true)
 #define false_val() i1_val(false)
 #define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__)
 #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
 #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
+#define i8_val(val) int_val(8, val)
+#define i16_val(val) int_val(16, val)
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
 #define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
-#define int_val(width, val) \
-  LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
 #define tid_val() getThreadId(rewriter, loc)
 
 // Attributes
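Note: the new i8_val and i16_val are thin wrappers that expand through int_val, which is why int_val's definition moves above them. A minimal standalone C++ sketch of that macro layering, with a stand-in for LLVM::createLLVMIntegerConstant (the real macros expand to MLIR builder calls and need rewriter and loc in scope):

#include <cstdint>
#include <cstdio>

// Stand-in for LLVM::createLLVMIntegerConstant so the sketch runs on its
// own; it only mimics truncation to the requested bit width.
static int64_t createIntegerConstant(unsigned bitwidth, int64_t val) {
  uint64_t mask = (bitwidth >= 64) ? ~0ull : ((1ull << bitwidth) - 1);
  return (int64_t)((uint64_t)val & mask);
}

#define int_val(bitwidth, val) createIntegerConstant(bitwidth, val)
#define i8_val(val) int_val(8, val)
#define i16_val(val) int_val(16, val)

int main() {
  printf("%lld\n", (long long)i8_val(42));   // expands to int_val(8, 42)
  printf("%lld\n", (long long)i16_val(300)); // expands to int_val(16, 300)
  return 0;
}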

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 6 additions & 4 deletions
@@ -81,10 +81,12 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
     if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
       return op->emitOpError("expected all operands to have the same rank");
     // Check if the first two operands share a common dimension
-    if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
-      return op->emitOpError("expected the last dimension of the first operand "
-                             "to be equal to the second-to-last dimension of "
-                             "the second operand");
+    // TODO: enable back with an interface to support scaled dot.
+    // if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
+    //   return op->emitOpError("expected the last dimension of the first
+    //   operand "
+    //                          "to be equal to the second-to-last dimension of
+    //                          " "the second operand");
     // Check the batch dimension
     if (aShape.size() == 3 &&
         (aShape[0] != cShape[0] || bShape[0] != cShape[0]))
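Note: the K-dimension check is disabled because tt.dot_scaled operands are packed integer tensors, so the raw operand shapes no longer agree on K. A small C++ illustration under assumed shapes (e2m1/fp4 packs two values per i8):

#include <cassert>

int main() {
  // Assumed example: lhs holds e2m1 (fp4) values packed two per i8,
  // rhs holds e4m3 (fp8) values, one per i8.
  const int M = 128, N = 128, K = 64;
  const int lhsShape[2] = {M, K / 2}; // K/2 bytes encode K logical values
  const int rhsShape[2] = {K, N};
  (void)M; (void)N;
  // The old verifier compared raw shapes, which now disagree ...
  assert(lhsShape[1] != rhsShape[0]);
  // ... even though the logical K dimensions still match.
  assert(lhsShape[1] * 2 == rhsShape[0]);
  return 0;
}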

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 14 additions & 0 deletions
@@ -119,4 +119,18 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }
 
+// Type for F8F6F4 kind of floats.
+def TT_F8F6F4TypeAttr : I32EnumAttr<
+  "F8F6F4Type", "",
+  [
+    I32EnumAttrCase<"E4M3", 0, "e4m3">,
+    I32EnumAttrCase<"E5M2", 1, "e5m2">,
+    I32EnumAttrCase<"E2M3", 2, "e2m3">,
+    I32EnumAttrCase<"E3M2", 3, "e3m2">,
+    I32EnumAttrCase<"E2M1", 4, "e2m1">
+
+  ]>{
+  let cppNamespace = "::mlir::triton";
+}
+
 #endif
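Note: the case names encode 1 sign bit plus eXmY exponent/mantissa bits, giving 8-, 6-, and 4-bit element types. A hedged sketch of the C++ enum that I32EnumAttr generates (in namespace ::mlir::triton in the real build), plus an illustrative bitWidth helper that is not part of the commit:

#include <cstdint>

// Mirror of the enum tablegen generates from TT_F8F6F4TypeAttr.
enum class F8F6F4Type : uint32_t {
  E4M3 = 0, E5M2 = 1, E2M3 = 2, E3M2 = 3, E2M1 = 4
};

// Illustrative helper (not in the commit): 1 sign bit + exponent bits +
// mantissa bits, as the eXmY names suggest.
constexpr unsigned bitWidth(F8F6F4Type t) {
  switch (t) {
  case F8F6F4Type::E4M3: return 1 + 4 + 3; // 8-bit
  case F8F6F4Type::E5M2: return 1 + 5 + 2; // 8-bit
  case F8F6F4Type::E2M3: return 1 + 2 + 3; // 6-bit
  case F8F6F4Type::E3M2: return 1 + 3 + 2; // 6-bit
  case F8F6F4Type::E2M1: return 1 + 2 + 1; // 4-bit
  }
  return 0;
}

static_assert(bitWidth(F8F6F4Type::E2M1) == 4, "fp4 elements are 4 bits");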

include/triton/Dialect/Triton/IR/TritonDialect.td

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ def Triton_Dialect : Dialect {
     "arith::ArithDialect",
     "math::MathDialect",
     "scf::SCFDialect",
-    "cf::ControlFlowDialect"
+    "cf::ControlFlowDialect",
+    "ub::UBDialect"
   ];
 
   let extraClassDeclaration = [{

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 37 additions & 0 deletions
@@ -673,6 +673,43 @@ def TT_DotOp : TT_Op<"dot", [Pure,
   let hasVerifier = 1;
 }
 
+
+//
+// DotScaled Op
+//
+def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
+                                          DotLike,
+                                          TypesMatchWith<"result's type matches accumulator's type",
+                                                         "d", "c", "$_self">]> {
+  let summary = "dot_scaled";
+
+  let description = [{
+    $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c.
+    Where scale(x, s) is a function that applies the scale per block following microscaling spec.
+  }];
+
+  let arguments = (
+    ins
+    // inputs are integer types as they are packed types and we currently
+    // don't have a representation for those.
+    TT_IntTensor:$lhs,
+    TT_IntTensor:$rhs,
+    TT_FloatTensor:$c,
+    TT_IntTensor:$lhs_scale,
+    Optional<TT_IntTensor>:$rhs_scale,
+    TT_F8F6F4TypeAttr:$lhs_type,
+    TT_F8F6F4TypeAttr:$rhs_type
+  );
+
+  let results = (outs TT_FloatTensor:$d);
+
+  // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file
+  let assemblyFormat = [{
+    $lhs `,` $lhs_scale `,` $rhs (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
+    `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
+  }];
+}
+
 //
 // Reduce Op
 //
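Note: the scale(x, s) in the op description follows the OCP microscaling (MX) spec: each block of elements (32 by default in the spec) shares one E8M0 scale, a bare exponent with bias 127. A runnable C++ sketch of that decode step; the names here are illustrative, not the kernel implementation:

#include <cmath>
#include <cstdint>
#include <cstdio>

// scale(x, s) per the MX spec: the shared scale is an E8M0 exponent,
// so the decoded value is x * 2^(s - 127).
static float applyScale(float x, uint8_t e8m0) {
  return x * std::ldexp(1.0f, (int)e8m0 - 127);
}

int main() {
  // One 32-element block shares a single scale byte; 128 encodes 2^1.
  const uint8_t blockScale = 128;
  printf("%f\n", applyScale(1.5f, blockScale)); // prints 3.0
  return 0;
}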

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 2 deletions
@@ -710,7 +710,7 @@ for
   // starting from the contiguous dimension
   for (unsigned d = 0; d < rank - 1; ++d) {
     unsigned i = order[d];
-    unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
+    unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, std::max<unsigned>(1, shapePerCTA[i] / sizePerThread[i]));
     threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
     warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
     remainingWarps /= warpsPerCTA[i];
@@ -743,7 +743,7 @@ for
   // starting from the most strided dimension
   for (int d = rank - 1; d >= 0; --d) {
     unsigned i = order[d];
-    CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
+    CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, std::max<unsigned>(1, shape[i] / sizePerThread[i]));
     CTASplitNum[i] = CTAsPerCGA[i];
     remainingCTAs /= CTAsPerCGA[i];
   }
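Note on the fix: std::clamp requires lo <= hi, and when sizePerThread[i] exceeds shapePerCTA[i] the integer division yields 0, so the old code called std::clamp(v, 1, 0), which is undefined behavior. A minimal reproduction with assumed values:

#include <algorithm>
#include <cstdio>

int main() {
  unsigned remainingThreads = 64;
  unsigned shapePerCTA = 16, sizePerThread = 32; // assumed values
  // Old: std::clamp(64u, 1u, 16/32) == std::clamp(64u, 1u, 0u): UB.
  // New: the upper bound is pinned to at least 1.
  unsigned hi = std::max<unsigned>(1, shapePerCTA / sizePerThread); // 0 -> 1
  unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, hi);
  printf("%u\n", threadsPerCTA); // prints 1
  return 0;
}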

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 20 additions & 0 deletions
@@ -256,4 +256,24 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
   }];
 }
 
+def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Convert an mxfp tensor to bf16";
+
+  let hasVerifier = 1;
+
+  let description = [{
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+  let arguments = (ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_F8F6F4TypeAttr:$fp_type);
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+}
+
 #endif
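Note: to make the description concrete, here is a hedged sketch of the per-element decode for the narrowest format, e2m1 (fp4). The magnitude table follows the OCP MX v1.0 spec linked above; the actual op also rounds to bf16 and applies the per-block scale, which this sketch omits:

#include <cstdint>
#include <cstdio>

// Decode one e2m1 nibble (1 sign, 2 exponent, 1 mantissa bit) to float.
// The eight magnitudes per the MX spec are {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
static float decodeE2M1(uint8_t nibble) {
  static const float magnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                     2.0f, 3.0f, 4.0f, 6.0f};
  float m = magnitude[nibble & 0x7];
  return (nibble & 0x8) ? -m : m;
}

int main() {
  printf("%f\n", decodeE2M1(0x5)); // prints 3.0
  printf("%f\n", decodeE2M1(0xD)); // prints -3.0
  return 0;
}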

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 2 deletions
@@ -28,8 +28,7 @@ class SharedEncodingAttr;
 // Version = 3: <m, n, k>
 SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
                                                 const ArrayRef<int64_t> &shape,
-                                                RankedTensorType type,
-                                                int numWarps);
+                                                Type type, int numWarps);
 
 // Return true if the Load uses block pointer.
 bool isLoadFromTensorPtr(triton::LoadOp op);

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 3 additions & 2 deletions
@@ -553,8 +553,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
       GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
       GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
-      GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
-                                                             context);
+      // this assumes the right layout will be set later for dot scaled.
+      GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
+      TritonFuncOpPattern>(typeConverter, context);
 }
 
 //
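Note: GenericOpPattern is this pass's catch-all conversion; it only swaps the op's tensor types for their layout-carrying TritonGPU counterparts, which is why the in-diff comment defers the real dot_scaled layout to later passes. A simplified C++ sketch of its likely shape (an assumption about the pattern's structure, not the verbatim Triton code):

#include "mlir/Transforms/DialectConversion.h"

// Simplified sketch: convert the result types through the TritonGPU type
// converter (which attaches a default encoding) and recreate the op over
// the already-converted operands.
template <typename OpT>
struct GenericOpPatternSketch : public mlir::OpConversionPattern<OpT> {
  using mlir::OpConversionPattern<OpT>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<mlir::Type> retTypes;
    if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
                                                      retTypes)))
      return mlir::failure();
    rewriter.replaceOpWithNewOp<OpT>(op, retTypes, adaptor.getOperands(),
                                     op->getAttrs());
    return mlir::success();
  }
};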
