Skip to content

Commit 2c8f3e2

Browse files
authored
Vectorize TritonGEN::FToTf32Op (#4811)
This change reduces `.llir` file one of the largest Flex Attn kernels by ~16k lines of code: 79094 -> 63276. It seems vectorization on our side is necessary if we want to reduce compilation time. --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent c01b6ec commit 2c8f3e2

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

test/Conversion/intel/tritongpu_to_gen_dot.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
7676
// CHECK-SAME: %[[A:.*]]: !llvm.struct<(f32, f32, f32, f32)>, %[[B:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 16, 1, 1>} {
7777
tt.func @dot_f32_tf32_tf32_f32_1(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x16xf32, #dot_operand_b>, %c: tensor<8x16xf32, #dpas>) {
7878
// COM: To simplify, only check RTNE and its usage for the last element of A, B, C
79-
// CHECK: %[[A_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
80-
// CHECK: %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[A_LAST_VAL]])
81-
// CHECK: %[[A_0:.*]] = llvm.insertelement %[[A_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<4xf32>
82-
// CHECK: %[[B_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
83-
// CHECK: %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[B_LAST_VAL]])
84-
// CHECK: %[[B_0:.*]] = llvm.insertelement %[[B_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
79+
// CHECK: %[[A_EXTR_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
80+
// CHECK: %[[A_LAST_VAL:.*]] = llvm.insertelement %[[A_EXTR_LAST_VAL]], %{{.*}} : vector<4xf32>
81+
// CHECK: %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELDv4_f(%[[A_LAST_VAL]])
82+
// CHECK: %[[B_EXTR_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
83+
// CHECK: %[[B_LAST_VAL:.*]] = llvm.insertelement %[[B_EXTR_LAST_VAL]], %{{.*}} : vector<8xf32>
84+
// CHECK: %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELDv8_f(%[[B_LAST_VAL]])
8585
// CHECK: %[[C_LAST_VAL:.*]] = llvm.extractvalue %[[C]][7]
8686
// CHECK: %[[C_0:.*]] = llvm.insertelement %[[C_LAST_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
87-
// CHECK: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_fDv8_fS0_i(%{{.*}}, %[[A_0]], %[[B_0]], %[[C_0]], %{{.*}}} : (i32, vector<4xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
87+
// CHECK: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_fDv8_fS0_i(%{{.*}}, %[[A_RTNE_VAL]], %[[B_RTNE_VAL]], %[[C_0]], %{{.*}}} : (i32, vector<4xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
8888
%0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<8x8xf32, #dot_operand_a> * tensor<8x16xf32, #dot_operand_b> -> tensor<8x16xf32, #dpas>
8989
tt.return
9090
}

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,8 @@ def TritonGEN_FToTf32Op
394394
a 32-bit floating point type to TF32 with rounding to the nearest even.
395395
}];
396396

397-
let arguments = (ins F32:$val);
398-
let results = (outs F32:$res);
397+
let arguments = (ins LLVM_ScalarOrVectorOf<F32>:$val);
398+
let results = (outs LLVM_ScalarOrVectorOf<F32>:$res);
399399
let assemblyFormat = [{
400400
$val attr-dict `:` type($val)
401401
}];

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -923,13 +923,16 @@ struct TritonFToTf32OpLowering
923923
auto b = TritonLLVMOpBuilder(loc, rewriter);
924924

925925
Value value = op->getOperand(0);
926-
SmallVector<Type> argTypes{f32_ty};
926+
Type valueType = value.getType();
927+
928+
SmallVector<Type> argTypes{valueType};
927929
SmallVector<Value> args{value};
928930

929-
const StringLiteral funcName = "_Z25__spirv_RoundFToTF32INTELf";
930-
auto retType = f32_ty;
931+
std::string fnName = "__spirv_RoundFToTF32INTEL";
932+
fnName = intel::mangle(fnName, argTypes);
933+
auto retType = valueType;
931934
auto callOp = intel::createDeviceFunctionCall(
932-
rewriter, funcName, retType, {argTypes}, {args}, {},
935+
rewriter, fnName, retType, {argTypes}, {args}, {},
933936
intel::noUnwindWillReturnAttrs);
934937
rewriter.replaceOp(op, callOp);
935938
return success();

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -352,21 +352,18 @@ class DotOpDPASConversionHelper {
352352
for (int repInner = 0; repInner < repClusterInner; ++repInner) {
353353
Value matVal = rewriter.create<LLVM::UndefOp>(loc, dotOpTy);
354354
for (int k = 0; k < numElemsPerOperand; ++k) {
355-
if (isFToTF32Enabled) {
356-
Value f32Val = elems[offset++];
357-
auto t32Val =
358-
rewriter.create<TritonGEN::FToTf32Op>(loc, f32Val)
359-
.getResult();
360-
matVal =
361-
tb.insert_element(dotOpTy, matVal, t32Val, tb.i32_val(k));
362-
363-
} else {
364-
matVal = tb.insert_element(dotOpTy, matVal, elems[offset++],
365-
tb.i32_val(k));
366-
}
355+
matVal = tb.insert_element(dotOpTy, matVal, elems[offset++],
356+
tb.i32_val(k));
357+
}
358+
if (isFToTF32Enabled) {
359+
auto t32Val = rewriter.create<TritonGEN::FToTf32Op>(loc, matVal)
360+
.getResult();
361+
vals[{b, i * repClusterOuter + repOuter,
362+
j * repClusterInner + repInner}] = t32Val;
363+
} else {
364+
vals[{b, i * repClusterOuter + repOuter,
365+
j * repClusterInner + repInner}] = matVal;
367366
}
368-
vals[{b, i * repClusterOuter + repOuter,
369-
j * repClusterInner + repInner}] = matVal;
370367
}
371368
}
372369
}

0 commit comments

Comments
 (0)