
Commit 21b69a0

Improved performance of the fp4tofp conversion (#4440)
Use a simple lookup table instead of explicit conversion. Use bitwise operations on vectors to build the indices.

Fixes #4298

------------

This is another (and perhaps over-complicated) version of #4299, which uses bitwise operations on whole vectors (not per element) to build the indices. The resulting LLVM IR is the following:

```llvm
%idx_vec0 = lshr <16 x i8> %i8vec, splat (i8 4)
%idx_vec1 = and <16 x i8> %i8vec, splat (i8 15)
%idx0 = extractelement <16 x i8> %idx_vec0, i64 0
%bf0 = extractelement <16 x bfloat> <bfloat 0xR0000, bfloat 0xR3F00, bfloat 0xR3F80, bfloat 0xR3FC0, bfloat 0xR4000, bfloat 0xR4040, bfloat 0xR4080, bfloat 0xR40C0, bfloat 0xR8000, bfloat 0xRBF00, bfloat 0xRBF80, bfloat 0xRBFC0, bfloat 0xRC000, bfloat 0xRC040, bfloat 0xRC080, bfloat 0xRC0C0>, i8 %idx0
%idx1 = extractelement <16 x i8> %idx_vec1, i64 0
%bf1 = extractelement <16 x bfloat> <bfloat 0xR0000, bfloat 0xR3F00, bfloat 0xR3F80, bfloat 0xR3FC0, bfloat 0xR4000, bfloat 0xR4040, bfloat 0xR4080, bfloat 0xR40C0, bfloat 0xR8000, bfloat 0xRBF00, bfloat 0xRBF80, bfloat 0xRBFC0, bfloat 0xRC000, bfloat 0xRC040, bfloat 0xRC080, bfloat 0xRC0C0>, i8 %idx1
...
```

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
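As a side note (not part of the change itself), the lookup-table scheme is easy to sketch on the host side. The Python snippet below is a minimal illustration, assuming each byte packs two fp4 (E2M1) values and using the 16-entry table that appears as the `vector<16xbf16>` constant in the lowering; the function and variable names are made up for this example, and which nibble maps to the first output element is treated as a packing-convention assumption.

```python
import numpy as np

# 16-entry fp4 (E2M1) value table; mirrors the vector<16xbf16> constant
# emitted by the lowering: 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negations.
# bf16 is not a native numpy dtype, so float32 stands in here.
FP4_TABLE = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32)

def fp4_to_float(packed: np.ndarray) -> np.ndarray:
    """Unpack a 1-D uint8 array of packed fp4 pairs via table lookup."""
    lo = packed & 0x0F          # low nibble  -> one fp4 code per byte
    hi = (packed >> 4) & 0x0F   # high nibble -> the other fp4 code
    out = np.empty(packed.size * 2, dtype=np.float32)
    # Assumption: low nibble comes first in the unpacked order.
    out[0::2] = FP4_TABLE[lo]
    out[1::2] = FP4_TABLE[hi]
    return out

# Example: 0x21 packs fp4 codes 1 (0.5) and 2 (1.0).
print(fp4_to_float(np.array([0x21], dtype=np.uint8)))  # [0.5 1. ]
```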
1 parent ab774ad commit 21b69a0

File tree

2 files changed: +355 −33 lines changed


test/TritonIntelGPU/fp4tofp.mlir

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm -canonicalize | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [16], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [32], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir16", ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @fp4_to_bf16_kernel(%arg0: !tt.ptr<i8>, %arg1: !tt.ptr<bf16>) {
    %c1_i64 = arith.constant 1 : i64
    %c16_i32 = arith.constant 16 : i32
    %c16_i64 = arith.constant 16 : i64
    %c32_i32 = arith.constant 32 : i32
    %c32_i64 = arith.constant 32 : i64
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c16_i32 : i32
    %2 = tt.make_tensor_ptr %arg0, [%c16_i64], [%c1_i64], [%1] {order = array<i32: 0>, tt.divisibility = dense<16> : tensor<1xi32>} : <tensor<16xi8, #blocked>>
    %3 = tt.load %2 : !tt.ptr<tensor<16xi8, #blocked>>
    %4 = ttg.fp4_to_fp %3 {axis = 0 : i32} : tensor<16xi8, #blocked> -> tensor<32xbf16, #blocked1>
    %5 = arith.muli %0, %c32_i32 : i32
    %6 = tt.make_tensor_ptr %arg1, [%c32_i64], [%c1_i64], [%5] {order = array<i32: 0>, tt.divisibility = dense<16> : tensor<1xi32>} : <tensor<32xbf16, #blocked1>>
    tt.store %6, %4 : !tt.ptr<tensor<32xbf16, #blocked1>>
    tt.return
  }
}
// CHECK-DAG: [[C4V:%.+]] = llvm.mlir.constant(dense<4> : vector<4xi32>) : vector<4xi32>
// CHECK-DAG: [[C15V:%.+]] = llvm.mlir.constant(dense<252645135> : vector<4xi32>) : vector<4xi32>
// CHECK-DAG: [[TABLE:%.+]] = llvm.mlir.constant(dense<[0.000000e+00, 5.000000e-01, 1.000000e+00, 1.500000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 6.000000e+00, -0.000000e+00, -5.000000e-01, -1.000000e+00, -1.500000e+00, -2.000000e+00, -3.000000e+00, -4.000000e+00, -6.000000e+00]> : vector<16xbf16>) : vector<16xbf16>
// CHECK-DAG: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[I32V:%.+]] = llvm.load {{.+}} : !llvm.ptr<1> -> vector<4xi32>
// CHECK: [[IDXV0I32:%.+]] = llvm.and [[I32V]], [[C15V]] : vector<4xi32>
// CHECK: [[LSHR:%.+]] = llvm.lshr [[I32V]], [[C4V]] : vector<4xi32>
// CHECK: [[IDXV1I32:%.+]] = llvm.and [[LSHR]], [[C15V]] : vector<4xi32>
// CHECK: [[IDX0I32:%.+]] = llvm.extractelement [[IDXV0I32]][[[C3]] : i32] : vector<4xi32>
// CHECK: [[IDX1I32:%.+]] = llvm.extractelement [[IDXV1I32]][[[C3]] : i32] : vector<4xi32>
// CHECK: [[IDXV0I8:%.+]] = llvm.bitcast [[IDX0I32]] : i32 to vector<4xi8>
// CHECK: [[IDXV1I8:%.+]] = llvm.bitcast [[IDX1I32]] : i32 to vector<4xi8>
// CHECK: [[IDX0I8:%.+]] = llvm.extractelement [[IDXV0I8]][[[C3]] : i32] : vector<4xi8>
// CHECK: [[V0:%.+]] = llvm.extractelement [[TABLE]][[[IDX0I8]] : i8] : vector<16xbf16>
// CHECK: [[IDX1I8:%.+]] = llvm.extractelement [[IDXV1I8]][[[C3]] : i32] : vector<4xi8>
// CHECK: [[V1:%.+]] = llvm.extractelement [[TABLE]][[[IDX1I8]] : i8] : vector<16xbf16>

// -----

#blocked = #ttg.blocked<{sizePerThread = [16], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [32], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir16", ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @fp4_to_bf16_kernel(%arg0: !tt.ptr<i8>, %arg1: !tt.ptr<bf16>) {
    %c1_i64 = arith.constant 1 : i64
    %c16_i32 = arith.constant 16 : i32
    %c16_i64 = arith.constant 16 : i64
    %c32_i32 = arith.constant 32 : i32
    %c32_i64 = arith.constant 32 : i64
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c16_i32 : i32
    %2 = tt.make_tensor_ptr %arg0, [%c16_i64], [%c1_i64], [%1] {order = array<i32: 0>, tt.divisibility = dense<4> : tensor<1xi32>} : <tensor<16xi8, #blocked>>
    %3 = tt.load %2 : !tt.ptr<tensor<16xi8, #blocked>>
    %4 = ttg.fp4_to_fp %3 {axis = 0 : i32} : tensor<16xi8, #blocked> -> tensor<32xbf16, #blocked1>
    %5 = arith.muli %0, %c32_i32 : i32
    %6 = tt.make_tensor_ptr %arg1, [%c32_i64], [%c1_i64], [%5] {order = array<i32: 0>, tt.divisibility = dense<4> : tensor<1xi32>} : <tensor<32xbf16, #blocked1>>
    tt.store %6, %4 : !tt.ptr<tensor<32xbf16, #blocked1>>
    tt.return
  }
}
// CHECK-DAG: [[C4V:%.+]] = llvm.mlir.constant(dense<4> : vector<4xi8>) : vector<4xi8>
// CHECK-DAG: [[C15V:%.+]] = llvm.mlir.constant(dense<15> : vector<4xi8>) : vector<4xi8>
// CHECK-DAG: [[TABLE:%.+]] = llvm.mlir.constant(dense<[0.000000e+00, 5.000000e-01, 1.000000e+00, 1.500000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 6.000000e+00, -0.000000e+00, -5.000000e-01, -1.000000e+00, -1.500000e+00, -2.000000e+00, -3.000000e+00, -4.000000e+00, -6.000000e+00]> : vector<16xbf16>) : vector<16xbf16>
// CHECK-DAG: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[I32:%.+]] = llvm.load {{.+}} {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK: [[IDXVI8:%.+]] = llvm.bitcast [[I32]] : i32 to vector<4xi8>
// CHECK: [[IDXV0I8:%.+]] = llvm.and [[IDXVI8]], [[C15V]] : vector<4xi8>
// CHECK: [[IDXV1I8:%.+]] = llvm.lshr [[IDXVI8]], [[C4V]] : vector<4xi8>
// CHECK: [[IDX0I8:%.+]] = llvm.extractelement [[IDXV0I8]][[[C3]] : i32] : vector<4xi8>
// CHECK: [[V0:%.+]] = llvm.extractelement [[TABLE]][[[IDX0I8]] : i8] : vector<16xbf16>
// CHECK: [[IDX1I8:%.+]] = llvm.extractelement [[IDXV1I8]][[[C3]] : i32] : vector<4xi8>
// CHECK: [[V1:%.+]] = llvm.extractelement [[TABLE]][[[IDX1I8]] : i8] : vector<16xbf16>

// -----

#blocked = #ttg.blocked<{sizePerThread = [16], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [32], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir16", ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @fp4_to_bf16_kernel(%arg0: !tt.ptr<i8>, %arg1: !tt.ptr<bf16>) {
    %c1_i64 = arith.constant 1 : i64
    %c16_i32 = arith.constant 16 : i32
    %c16_i64 = arith.constant 16 : i64
    %c32_i32 = arith.constant 32 : i32
    %c32_i64 = arith.constant 32 : i64
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c16_i32 : i32
    %2 = tt.make_tensor_ptr %arg0, [%c16_i64], [%c1_i64], [%1] {order = array<i32: 0>, tt.divisibility = dense<1> : tensor<1xi32>} : <tensor<16xi8, #blocked>>
    %3 = tt.load %2 : !tt.ptr<tensor<16xi8, #blocked>>
    %4 = ttg.fp4_to_fp %3 {axis = 0 : i32} : tensor<16xi8, #blocked> -> tensor<32xbf16, #blocked1>
    %5 = arith.muli %0, %c32_i32 : i32
    %6 = tt.make_tensor_ptr %arg1, [%c32_i64], [%c1_i64], [%5] {order = array<i32: 0>, tt.divisibility = dense<1> : tensor<1xi32>} : <tensor<32xbf16, #blocked1>>
    tt.store %6, %4 : !tt.ptr<tensor<32xbf16, #blocked1>>
    tt.return
  }
}
// CHECK-DAG: [[C4:%.+]] = llvm.mlir.constant(4 : i8) : i8
// CHECK-DAG: [[C15:%.+]] = llvm.mlir.constant(15 : i8) : i8
// CHECK-DAG: [[TABLE:%.+]] = llvm.mlir.constant(dense<[0.000000e+00, 5.000000e-01, 1.000000e+00, 1.500000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 6.000000e+00, -0.000000e+00, -5.000000e-01, -1.000000e+00, -1.500000e+00, -2.000000e+00, -3.000000e+00, -4.000000e+00, -6.000000e+00]> : vector<16xbf16>) : vector<16xbf16>
// CHECK-DAG: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-COUNT-16: [[I8:%.+]] = llvm.load {{.+}} {alignment = 1 : i64} : !llvm.ptr<1> -> i8
// CHECK: [[I8V:%.+]] = llvm.bitcast [[I8]] : i8 to vector<1xi8>
// CHECK: [[I8:%.+]] = llvm.extractelement [[I8V]][[[C0]] : i32] : vector<1xi8>
// CHECK: [[IDX0I8:%.+]] = llvm.and [[I8]], [[C15]] : i8
// CHECK: [[IDX1I8:%.+]] = llvm.lshr [[I8]], [[C4]] : i8
// CHECK: [[V0:%.+]] = llvm.extractelement [[TABLE]][[[IDX0I8]] : i8] : vector<16xbf16>
// CHECK: [[V1:%.+]] = llvm.extractelement [[TABLE]][[[IDX1I8]] : i8] : vector<16xbf16>
