Commit 9c6817d

[tt.dot_scaled]: Support for scale on operand "B" (#2910)

This PR extends the `tt.dot_scaled` decomposition pass to allow the "B" operand to carry a scaling factor, in addition to the "A" operand.

Signed-off-by: Tiotto, Ettore <[email protected]>

Parent: 680965c
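For reference, the two accepted placements of the `scale` keyword, as exercised by the lit tests added in this commit (each form uses the #blocked layout names defined in its own test case; the accumulator of the second form is renamed to %acc here):

// Scale on the "A" operand (previously supported).
%r1 = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16
    : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1>
      * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>

// Scale on the "B" operand (new in this PR).
%r2 = tt.dot_scaled %a, %b scale %scale, %acc lhs = e4m3 rhs = e2m1
    : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>,
      tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1>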

File tree

3 files changed (+197, -87 lines)


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 4 additions & 5 deletions
@@ -122,9 +122,9 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     newShape[kIdx] *= 2;
     Type elemType = FloatType::getBF16(ctx);
 
-    // Note: For Intel the dot operands layout's kWidth parameter must
-    // match the parent's DPAS layout opsPerChannel so we need to materialize
-    // a new DPAS layout.
+    // Note: For Intel the dot operands layout's kWidth parameter must match
+    // the parent's DPAS layout opsPerChannel so we need to materialize a new
+    // DPAS layout.
     Attribute newVEncoding;
     if (auto dpasEncoding =
             dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
@@ -135,8 +135,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
           dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
           dpasEncoding.getSubGroupSize());
       newVEncoding = DotOperandEncodingAttr::get(
-          ctx, oldEncoding.getOpIdx(), newDpasEncoding,
-          newDpasEncoding.getOpsPerChannel());
+          ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
     } else {
       // Figure out the K dimension for the input A/B, given that the return
       // type is upcasted A/B type so we need to update the proper dim size.
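The substance of this hunk is the second change: the dot-operand encoding materialized for the upcast result is now built from a local `opIdx` value instead of re-querying `oldEncoding.getOpIdx()`. The definition of `opIdx` lies outside the hunk, but the effect visible in the new tests is that a B operand (opIdx = 1) now takes the same DPAS path as A. A minimal sketch of the resulting types, condensed from the @dot_scale_transpose CHECK lines below (`#dpas`/`#dpas1` stand in for the layouts captured there as [[DPAS]]/[[DPAS1]], `#blocked2` for the scale's blocked layout):

// B-operand upcast on Intel: the result's dot_op encoding keeps opIdx = 1,
// its kWidth (2) matches the new DPAS layout's opsPerChannel, and the K
// dimension doubles because e2m1 packs two fp4 values per i8.
%B = ttg.upcast_mxfp %b, %scale fp_type = e2m1
       : tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = #dpas1, kWidth = 4}>>,
         tensor<32x2xi8, #blocked2>
      -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>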

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 71 additions & 9 deletions
@@ -208,25 +208,87 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 
-module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
-  // CHECK: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
-  // CHECK: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-  // CHECK: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
-  // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
-  // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
-  // CHECK: dot_scaled
+module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32} {
+  // CHECK-DAG: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+  // CHECK-DAG: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+  // CHECK-DAG: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
+
+  // CHECK: tt.func @dot_scaled([[ARG0:%.*]]: tensor<128x32xi8, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<128x2xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<64x128xbf16, [[BLOCKED2]]>) -> tensor<128x128xf32, [[BLOCKED2]]> {
   tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
+    // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
+    // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+    // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
+    // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[B:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
+
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %dot_res1 = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.return %dot_res1 : tensor<128x128xf32, #blocked>
+  }
+
+  // CHECK: tt.func @dot_scaled_fp8([[ARG0:%.*]]: tensor<128x32xi8, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<128x2xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<64x128xf8E4M3FN, [[BLOCKED2]]>) -> tensor<128x128xf32, [[BLOCKED2]]> {
+  tt.func @dot_scaled_fp8(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xf8E4M3FN, #blocked>) -> tensor<128x128xf32, #blocked> {
     // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
     // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
     // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
     // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
     // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
-    // CHECK: [[B:%.*]] = ttg.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xf8E4M3FN, [[BLOCKED2]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG2]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
     // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
     // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
 
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
-    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked>
     tt.return %result : tensor<128x128xf32, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+module attributes {ttg.target = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32} {
+  // CHECK-DAG: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED3:#.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+  // CHECK-DAG: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
+
+  // CHECK: tt.func @dot_scale_transpose([[ARG0:%.*]]: tensor<128x64xf8E4M3FN, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<32x32xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<32x2xi8, [[BLOCKED2]]>, %arg3: tensor<128x32x!tt.ptr<bf16>, [[BLOCKED3]]>) {
+  tt.func @dot_scale_transpose(%a: tensor<128x64xf8E4M3FN, #blocked>, %b: tensor<32x32xi8, #blocked1>, %scale: tensor<32x2xi8, #blocked2>, %d: tensor<128x32x!tt.ptr<bf16>, #blocked3>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c100_i32 = arith.constant 100 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked3>
+    %cst_1 = arith.constant dense<2> : tensor<32x1xi32, #blocked2>
+    // CHECK: scf.for {{.*}} iter_args([[ARG5:%.*]] = %cst)
+    %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 {
+      // CHECK: [[C:%.*]] = ttg.convert_layout [[ARG5]] : tensor<128x32xf32, [[BLOCKED1]]> -> tensor<128x32xf32, [[DPAS]]>
+      // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<32x32xi8, [[BLOCKED1]]> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = [[DPAS1]], kWidth = 4}>>
+      // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<32x2xi8, [[BLOCKED2]]> -> tensor<32x2xi8, [[BLOCKED2]]>
+      // CHECK: [[B:%.*]] = ttg.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = [[DPAS1]], kWidth = 4}>>, tensor<32x2xi8, [[BLOCKED2]]> -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+      // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x64xf8E4M3FN, [[BLOCKED]]> -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
+      // CHECK: [[A:%.*]] = tt.fp_to_fp [[CVT_ARG0]] : tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
+      // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x32xf32, [[DPAS]]>
+      // CHECK: [[RES:%.*]] = ttg.convert_layout [[D]] : tensor<128x32xf32, [[DPAS]]> -> tensor<128x32xf32, [[BLOCKED1]]>
+      // CHECK: scf.yield [[RES]] : tensor<128x32xf32, [[BLOCKED1]]>
+      %3 = tt.dot_scaled %a, %b scale %scale, %arg5 lhs = e4m3 rhs = e2m1 : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1>
+      scf.yield %3 : tensor<128x32xf32, #blocked1>
+    }
+    %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1>
+    %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3>
+    tt.store %d, %2 : tensor<128x32x!tt.ptr<bf16>, #blocked3>
+    tt.return
+  }
+}
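Taken together, the CHECK lines in @dot_scale_transpose spell out the decomposition for a B-scaled `tt.dot_scaled`: the scaled i8 (e2m1-packed) operand is dequantized with ttg.upcast_mxfp, the unscaled fp8 operand is widened with tt.fp_to_fp, and an ordinary tt.dot runs on the resulting bf16 operands. A condensed, comment-only summary (not valid IR as-is; types and layouts elided, see the exact CHECK lines above):

// Decomposition of a B-scaled tt.dot_scaled, condensed from the CHECK lines:
//   %c    = ttg.convert_layout %acc        // accumulator -> DPAS layout
//   %b_bf = ttg.upcast_mxfp %b, %scale fp_type = e2m1
//                                          // packed e2m1 i8 -> bf16, K doubled
//   %a_bf = tt.fp_to_fp %a                 // f8E4M3FN -> bf16
//   %d    = tt.dot %a_bf, %b_bf, %c        // plain bf16 dot in the DPAS layout
//   %res  = ttg.convert_layout %d          // back to the original blocked layout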
