-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" | FileCheck %s --check-prefixes MFMA0,CHECK
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" | FileCheck %s --check-prefixes MFMA16,CHECK
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %s --check-prefixes MFMA0,CHECK
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %s --check-prefixes MFMA16,CHECK
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
-// CHECK-LABEL: mfma_dot_fp8e5m2
+// CHECK-LABEL: mfma_dot_fp8e5m2_fp8e4m3fn
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @mfma_dot_fp8e5m2(
+  tt.func public @mfma_dot_fp8e5m2_fp8e4m3fn(
     %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
-    %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg1: tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
     %arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
     // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E4M3FN, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    // CHECK: tt.dot %[[A1]], %[[B1]]
+    // expected-remark @+2 {{missing native support for fp8 variant on current architecture; emulated with fp16 so low performance}}
+    // expected-remark @+1 {{for gfx942 please use native supported fp8 variants}}
+    %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_fp8e4m3fn_fp8e5m2
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_fp8e4m3fn_fp8e5m2(
+    %arg0: tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+    // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
     // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
     // CHECK: tt.dot %[[A1]], %[[B1]]
-    %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    // expected-remark @+2 {{missing native support for fp8 variant on current architecture; emulated with fp16 so low performance}}
+    // expected-remark @+1 {{for gfx942 please use native supported fp8 variants}}
+    %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
     tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
     tt.return
   }
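Side note on the updated RUN lines: lit expands %s to the path of this test file, so the first RUN line is roughly equivalent to the manual invocation sketched below. The file name accelerate-amd-matmul.mlir is only a placeholder for this test; the pass options and check prefixes are copied verbatim from the RUN line itself.

    # Placeholder file name; flags mirror the first RUN line above.
    triton-opt accelerate-amd-matmul.mlir -split-input-file \
        --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" \
        --verify-diagnostics \
      | FileCheck accelerate-amd-matmul.mlir --check-prefixes MFMA0,CHECK

--verify-diagnostics is what makes the new expected-remark lines take effect: without it the remarks emitted by the pass would go unchecked.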
|
|