|
1 |
| -// RUN: env TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s |
| 1 | +// RUN: env TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32=1 TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s |
2 | 2 |
|
3 | 3 | // CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 1], A = [32, 16], B = [16, 16], C = [32, 16]}>
|
4 | 4 | // CHECK: #[[$DPAS_1:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
|
@@ -368,3 +368,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
|
368 | 368 | tt.return
|
369 | 369 | }
|
370 | 370 | }
|
| 371 | + |
| 372 | +// ----- |
| 373 | + |
| 374 | +// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 32, warpsPerCTA = [1, 1], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}> |
| 375 | +#blocked = #ttg.blocked<{sizePerThread = [4, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> |
| 376 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32, "ttig.min_sg_size" = 16 : i32, "ttig.support_dpas"} { |
| 377 | + // CHECK-LABEL: dpas_sub_group_size_32 |
| 378 | + tt.func @dpas_sub_group_size_32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { |
| 379 | + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked> |
| 380 | + %a = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> |
| 381 | + %b = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> |
| 382 | + |
| 383 | + // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 1}>> -> tensor<128x16xf32, #[[$DPAS]]> |
| 384 | + %result = tt.dot %a, %b, %zero_f32, inputPrecision = tf32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf32, #blocked> |
| 385 | + %result_ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked> |
| 386 | + tt.store %result_ptr, %result : tensor<128x16x!tt.ptr<f32>, #blocked> |
| 387 | + tt.return |
| 388 | + } |
| 389 | +} |
0 commit comments