66
77#include " PatternTritonGPUOpToLLVM.h"
88
9+ #include " TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
910#include " mlir/IR/Value.h"
1011#include " mlir/IR/ValueRange.h"
1112#include " mlir/Transforms/DialectConversion.h"
@@ -19,6 +20,73 @@ using namespace mlir;
1920using namespace mlir ::triton;
2021using namespace mlir ::triton::gpu;
2122
23+ // Convert 8 fp4 elements packed into one 32-bit register into 8 bf16
24+ // elements packed into four 32-bit registers.
// Inline PTX body consumed by createInlineAsmUpcast below. Operand numbering
// follows the order in which that function registers operands: $0..$3 are the
// four 32-bit outputs and $4 is the single 32-bit input.
// `.reg .b32 a<14>` declares the scratch registers a0..a13; only a0..a12 are
// referenced in the body.
25+ static constexpr const char *ptxAsm =
26+ " {\n "
27+ " .reg .b32 a<14>;\n "
// -2004318072 == 0x88888888: keeps bit 3 of every nibble of the input.
// NOTE(review): presumably these are the fp4 sign bits - confirm against the
// e2m1 bit layout.
28+ " and.b32 a0, $4, -2004318072;\n\t "
29+ " shr.u32 a1, a0, 3;\n\t "
// 2004318071 == 0x77777777: keeps bits 0-2 of every nibble of the input.
30+ " and.b32 a2, $4, 2004318071;\n\t "
// a3 and a4 are the upper-16-bit counterparts of a2 and a1
// (a4 = a0 >> 19 = (a0 >> 3) >> 16).
31+ " shr.u32 a3, a2, 16;\n\t "
32+ " shr.u32 a4, a0, 19;\n\t "
// Constants in hex: -1065353216 == 0xC0800000, -1065336832 == 0xC0804000,
// 1061109504 == 0x3F3F3F00, 1077952576 == 0x40404040, 32768 == 0x00008000.
// NOTE(review): with a register as the selector operand, prmt appears to be
// used as a per-nibble byte lookup into the two constant operands - confirm
// against the PTX ISA description of prmt.
33+ " prmt.b32 a5, -1065353216, -1065336832, a2;\n\t "
34+ " prmt.b32 a6, -1065353216, -1065336832, a3;\n\t "
35+ " prmt.b32 a7, 1061109504, 1077952576, a2;\n\t "
36+ " prmt.b32 a8, 1061109504, 1077952576, a3;\n\t "
37+ " prmt.b32 a9, 32768, 0, a1;\n\t "
38+ " prmt.b32 a10, 32768, 0, a4;\n\t "
// Merge the a1/a4-derived bytes into the a2/a3-derived bytes.
39+ " or.b32 a11, a7, a9;\n\t "
40+ " or.b32 a12, a8, a10;\n\t "
// Selectors in hex: 20800 == 0x5140, 29538 == 0x7362.
// NOTE(review): these look like byte interleaves of (a5, a11) and (a6, a12),
// lower two byte pairs first, then upper two - confirm against the PTX ISA.
41+ " prmt.b32 $0, a5, a11, 20800;\n\t "
42+ " prmt.b32 $1, a5, a11, 29538;\n\t "
43+ " prmt.b32 $2, a6, a12, 20800;\n\t "
44+ " prmt.b32 $3, a6, a12, 29538;\n\t "
45+ " }" ;
46+
47+ static Value createInlineAsmUpcast (Location loc, RewriterBase &rewriter,
48+ Type retType, Value packedVec) {
49+ PTXBuilder builder;
50+ SmallVector<PTXBuilder::Operand *> operands;
51+ for (int i = 0 ; i < 4 ; i++) {
52+ operands.push_back (builder.newOperand (" =r" ));
53+ }
54+ operands.push_back (builder.newOperand (packedVec, " r" ));
55+ auto &ptxOp = *builder.create (ptxAsm);
56+ ptxOp (operands, /* onlyAttachMLIRArgs=*/ true );
57+ Value result = builder.launch (rewriter, loc, retType, false );
58+ return result;
59+ }
60+
61+ static SmallVector<Value> convertMxfp4x2ToBf16x2PTX (RewriterBase &rewriter,
62+ Location loc,
63+ ArrayRef<Value> values) {
64+ SmallVector<Value> results;
65+ MLIRContext *ctx = rewriter.getContext ();
66+ assert (values.size () % 4 == 0 );
67+ for (int i = 0 ; i < values.size (); i += 4 ) {
68+ Value v0 = values[i];
69+ Value v1 = values[i + 1 ];
70+ Value v2 = values[i + 2 ];
71+ Value v3 = values[i + 3 ];
72+ Value packedVec = undef (vec_ty (i8_ty, 4 ));
73+ packedVec = insert_element (packedVec, v0, i32_val (0 ));
74+ packedVec = insert_element (packedVec, v1, i32_val (1 ));
75+ packedVec = insert_element (packedVec, v2, i32_val (2 ));
76+ packedVec = insert_element (packedVec, v3, i32_val (3 ));
77+ SmallVector<Type> rets (4 , i32_ty);
78+ Type retType = struct_ty (rets);
79+ Value ret = createInlineAsmUpcast (loc, rewriter, retType, packedVec);
80+ for (int i = 0 ; i < 4 ; i++) {
81+ Value extractI32 = extract_val (ret, i);
82+ Value vecbf16 = bitcast (extractI32, vec_ty (bf16_ty, 2 ));
83+ results.push_back (extract_element (vecbf16, i32_val (0 )));
84+ results.push_back (extract_element (vecbf16, i32_val (1 )));
85+ }
86+ }
87+ return results;
88+ }
89+
2290namespace {
2391class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern <UpcastMXFPOp> {
2492private:
@@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
53121 cast<DotOperandEncodingAttr>(op.getType ().getEncoding ()).getKWidth ();
54122
55123 if (fpType == ScaleDotElemType::E2M1)
56- xVals = LLVM::convertMxfp4x2ToBf16x2 (rewriter, loc, xVals);
124+ xVals = convertMxfp4x2ToBf16x2PTX (rewriter, loc, xVals);
57125
58126 // Each thread owns elements of 4 mxfp vectors so we need 4 scales
59127 // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
0 commit comments