@@ -208,25 +208,87 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
208208#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [16 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
209209#blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
210210
211- module attributes {" ttg.target" = " xpu" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " triton_intel_gpu.min_sg_size" = 16 : i32 , " triton_intel_gpu.support_dpas" } {
212- // CHECK: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
213- // CHECK: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
214- // CHECK: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
215- // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
216- // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
217- // CHECK: dot_scaled
211+ module attributes {" ttg.target" = " xpu" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " triton_intel_gpu.min_sg_size" = 16 : i32 } {
212+ // CHECK-DAG: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
213+ // CHECK-DAG: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
214+ // CHECK-DAG: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
215+ // CHECK-DAG: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
216+ // CHECK-DAG: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
217+
218+ // CHECK: tt.func @dot_scaled([[ARG0:%.*]]: tensor<128x32xi8, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<128x2xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<64x128xbf16, [[BLOCKED2]]>) -> tensor<128x128xf32, [[BLOCKED2]]> {
218219 tt.func @dot_scaled (%a: tensor <128 x32 xi8 , #blocked2 >, %scale: tensor <128 x2 xi8 , #blocked1 >, %b: tensor <64 x128 xbf16 , #blocked >) -> tensor <128 x128 xf32 , #blocked > {
220+ // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
221+ // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
222+ // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
223+ // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
224+ // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
225+ // CHECK: [[B:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
226+ // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
227+ // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
228+
229+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #blocked >
230+ %dot_res1 = tt.dot_scaled %a scale %scale , %b , %cst lhs = e2m1 rhs = bf16 : tensor <128 x32 xi8 , #blocked2 >, tensor <128 x2 xi8 , #blocked1 > * tensor <64 x128 xbf16 , #blocked > -> tensor <128 x128 xf32 , #blocked >
231+ tt.return %dot_res1 : tensor <128 x128 xf32 , #blocked >
232+ }
233+
234+ // CHECK: tt.func @dot_scaled_fp8([[ARG0:%.*]]: tensor<128x32xi8, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<128x2xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<64x128xf8E4M3FN, [[BLOCKED2]]>) -> tensor<128x128xf32, [[BLOCKED2]]> {
235+ tt.func @dot_scaled_fp8 (%a: tensor <128 x32 xi8 , #blocked2 >, %scale: tensor <128 x2 xi8 , #blocked1 >, %b: tensor <64 x128 xf8 E4 M3 FN, #blocked >) -> tensor <128 x128 xf32 , #blocked > {
219236 // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
220237 // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
221238 // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
222239 // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
223240 // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
224- // CHECK: [[B:%.*]] = ttg.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
241+ // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xf8E4M3FN, [[BLOCKED2]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
242+ // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG2]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
225243 // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
226244 // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
227245
228246 %cst = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #blocked >
229- %result = tt.dot_scaled %a scale %scale , %b , %cst lhs = e2m1 rhs = bf16 : tensor <128 x32 xi8 , #blocked2 >, tensor <128 x2 xi8 , #blocked1 > * tensor <64 x 128 x bf16 , #blocked > -> tensor <128 x128 xf32 , #blocked >
247+ %result = tt.dot_scaled %a scale %scale , %b , %cst lhs = e2m1 rhs = e4m3 : tensor <128 x32 xi8 , #blocked2 >, tensor <128 x2 xi8 , #blocked1 > * tensor <64 x 128 xf 8 E 4 M 3 FN , #blocked > -> tensor <128 x128 xf32 , #blocked >
230248 tt.return %result : tensor <128 x128 xf32 , #blocked >
231249 }
232250}
251+
252+ // -----
253+
254+ #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
255+ #blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
256+ #blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [16 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
257+ #blocked3 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 2 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
258+
259+ module attributes {ttg.target = " xpu" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " triton_intel_gpu.min_sg_size" = 16 : i32 } {
260+ // CHECK-DAG: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
261+ // CHECK-DAG: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
262+ // CHECK-DAG: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
263+ // CHECK-DAG: [[BLOCKED3:#.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
264+ // CHECK-DAG: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
265+ // CHECK-DAG: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
266+
267+ // CHECK: tt.func @dot_scale_transpose([[ARG0:%.*]]: tensor<128x64xf8E4M3FN, [[BLOCKED]]>, [[ARG1:%.*]]: tensor<32x32xi8, [[BLOCKED1]]>, [[ARG2:%.*]]: tensor<32x2xi8, [[BLOCKED2]]>, %arg3: tensor<128x32x!tt.ptr<bf16>, [[BLOCKED3]]>) {
268+ tt.func @dot_scale_transpose (%a: tensor <128 x64 xf8 E4 M3 FN, #blocked >, %b: tensor <32 x32 xi8 , #blocked1 >, %scale: tensor <32 x2 xi8 , #blocked2 >, %d: tensor <128 x32 x!tt.ptr <bf16 >, #blocked3 >) {
269+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x32 xf32 , #blocked1 >
270+ %c1_i32 = arith.constant 1 : i32
271+ %c100_i32 = arith.constant 100 : i32
272+ %c0_i32 = arith.constant 0 : i32
273+ %cst_0 = arith.constant dense <32 > : tensor <32 x1 xi32 , #blocked3 >
274+ %cst_1 = arith.constant dense <2 > : tensor <32 x1 xi32 , #blocked2 >
275+ // CHECK: scf.for {{.*}} iter_args([[ARG5:%.*]] = %cst)
276+ %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args (%arg5 = %cst ) -> (tensor <128 x32 xf32 , #blocked1 >) : i32 {
277+ // CHECK: [[C:%.*]] = ttg.convert_layout [[ARG5]] : tensor<128x32xf32, [[BLOCKED1]]> -> tensor<128x32xf32, [[DPAS]]>
278+ // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<32x32xi8, [[BLOCKED1]]> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = [[DPAS1]], kWidth = 4}>>
279+ // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<32x2xi8, [[BLOCKED2]]> -> tensor<32x2xi8, [[BLOCKED2]]>
280+ // CHECK: [[B:%.*]] = ttg.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = [[DPAS1]], kWidth = 4}>>, tensor<32x2xi8, [[BLOCKED2]]> -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
281+ // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x64xf8E4M3FN, [[BLOCKED]]> -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
282+ // CHECK: [[A:%.*]] = tt.fp_to_fp [[CVT_ARG0]] : tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
283+ // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x32xf32, [[DPAS]]>
284+ // CHECK: [[RES:%.*]] = ttg.convert_layout [[D]] : tensor<128x32xf32, [[DPAS]]> -> tensor<128x32xf32, [[BLOCKED1]]>
285+ // CHECK: scf.yield [[RES]] : tensor<128x32xf32, [[BLOCKED1]]>
286+ %3 = tt.dot_scaled %a , %b scale %scale , %arg5 lhs = e4m3 rhs = e2m1 : tensor <128 x64 xf8 E4 M3 FN, #blocked > * tensor <32 x32 xi8 , #blocked1 >, tensor <32 x2 xi8 , #blocked2 > -> tensor <128 x32 xf32 , #blocked1 >
287+ scf.yield %3 : tensor <128 x32 xf32 , #blocked1 >
288+ }
289+ %1 = arith.truncf %0 : tensor <128 x32 xf32 , #blocked1 > to tensor <128 x32 xbf16 , #blocked1 >
290+ %2 = ttg.convert_layout %1 : tensor <128 x32 xbf16 , #blocked1 > -> tensor <128 x32 xbf16 , #blocked3 >
291+ tt.store %d , %2 : tensor <128 x32 x!tt.ptr <bf16 >, #blocked3 >
292+ tt.return
293+ }
294+ }
0 commit comments