|
1 | | -// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefix=GFX942 %s |
2 | | -// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefix=GFX950 %s |
| 1 | +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefixes=COMMON,GFX942 %s |
| 2 | +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefixes=COMMON,GFX950 %s |
3 | 3 |
|
4 | 4 | // CHECK-LABEL: f16_to_f32 |
5 | 5 | #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> |
@@ -32,15 +32,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr |
32 | 32 | // GFX942-COUNT-8: llvm.fptrunc %{{.+}} : f32 to f16 |
33 | 33 | // GFX950-COUNT-4: llvm.fptrunc %{{.+}} : vector<2xf32> to vector<2xf16> |
34 | 34 | %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> |
35 | | - // GFX942-COUNT-4: rocdl.cvt.pkrtz |
36 | | - // GFX950-COUNT-4: rocdl.cvt.pkrtz |
| 35 | + // COMMON-COUNT-4: rocdl.cvt.pkrtz |
37 | 36 | %1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> |
38 | 37 | tt.return |
39 | 38 | } |
40 | 39 | } |
41 | 40 |
|
42 | 41 | // ----- |
43 | 42 |
|
| 43 | +// COMMON-LABEL: f32_to_f16_single_value |
| 44 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> |
| 45 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { |
| 46 | + tt.func @f32_to_f16_single_value(%arg0: tensor<1x128xf32, #blocked>) { // label above uses the COMMON prefix because both RUN lines enable it; a plain CHECK prefix is never enabled |
| 47 | + // COMMON: llvm.fptrunc %{{.+}} : f32 to f16 |
| 48 | + // COMMON-NOT: llvm.fptrunc |
| 49 | + %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked> // rtne case: one scalar fptrunc, and no further fptrunc before the rtz lowering |
| 50 | + // COMMON: rocdl.cvt.pkrtz |
| 51 | + // COMMON-NOT: rocdl.cvt.pkrtz |
| 52 | + %1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked> // rtz case: one packed round-toward-zero convert, and none after it |
| 53 | + tt.return |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +// ----- |
| 58 | + |
44 | 59 | // CHECK-LABEL: downcast_to_f8 |
45 | 60 | #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> |
46 | 61 | module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { |
|
0 commit comments