|
| 1 | +// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=add_kernel outfn= argTys=enzyme_dup,enzyme_const,enzyme_dup,enzyme_const retTys= mode=ForwardMode" | FileCheck %s |
| 2 | + |
| 3 | +module { |
| 4 | + tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { // masked elementwise add: loads via %arg0 and %arg1, stores the sum via %arg2, bounded by %arg3 |
| 5 | + %block_size = arith.constant 1024 : i32 // width of the offset window built below |
| 6 | + %pid = tt.get_program_id x : i32 |
| 7 | + %block_start = arith.muli %pid, %block_size : i32 // first offset for this program id |
| 8 | + %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> // 0, 1, ..., 1023 |
| 9 | + %block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> |
| 10 | + %offsets = arith.addi %block_start_splat, %lane : tensor<1024xi32> // block_start + lane index |
| 11 | + %bound = tt.splat %arg3 : i32 -> tensor<1024xi32> |
| 12 | + %mask = arith.cmpi slt, %offsets, %bound : tensor<1024xi32> // true where offset < %arg3 (signed compare) |
| 13 | + %lhs_base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 14 | + %lhs_ptrs = tt.addptr %lhs_base, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 15 | + %lhs = tt.load %lhs_ptrs, %mask : tensor<1024x!tt.ptr<f32>> // first addend, read through %arg0 |
| 16 | + %rhs_base = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 17 | + %rhs_ptrs = tt.addptr %rhs_base, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 18 | + %rhs = tt.load %rhs_ptrs, %mask : tensor<1024x!tt.ptr<f32>> // second addend, read through %arg1 |
| 19 | + %sum = arith.addf %lhs, %rhs : tensor<1024xf32> |
| 20 | + %out_base = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 21 | + %out_ptrs = tt.addptr %out_base, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 22 | + tt.store %out_ptrs, %sum, %mask : tensor<1024x!tt.ptr<f32>> // write the sum through %arg2 under the same mask |
| 23 | + tt.return |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +// CHECK: tt.func @add_kernel(%[[arg0:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg1:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg2:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg3:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg4:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg5:.+]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { |
| 28 | +// CHECK-NEXT: %[[c1024_i32:.+]] = arith.constant 1024 : i32 |
| 29 | +// CHECK-NEXT: %[[v0:.+]] = tt.get_program_id x : i32 |
| 30 | +// CHECK-NEXT: %[[v1:.+]] = arith.muli %[[v0]], %[[c1024_i32]] : i32 |
| 31 | +// CHECK-NEXT: %[[v2:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> |
| 32 | +// CHECK-NEXT: %[[v3:.+]] = tt.splat %[[v1]] : i32 -> tensor<1024xi32> |
| 33 | +// CHECK-NEXT: %[[v4:.+]] = arith.addi %[[v3]], %[[v2]] : tensor<1024xi32> |
| 34 | +// CHECK-NEXT: %[[v5:.+]] = tt.splat %[[arg5]] : i32 -> tensor<1024xi32> |
| 35 | +// CHECK-NEXT: %[[v6:.+]] = arith.cmpi slt, %[[v4]], %[[v5]] : tensor<1024xi32> |
| 36 | +// CHECK-NEXT: %[[v7:.+]] = tt.splat %[[arg1]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 37 | +// CHECK-NEXT: %[[v8:.+]] = tt.splat %[[arg0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 38 | +// CHECK-NEXT: %[[v9:.+]] = tt.addptr %[[v7]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 39 | +// CHECK-NEXT: %[[v10:.+]] = tt.addptr %[[v8]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 40 | +// CHECK-NEXT: %[[v11:.+]] = tt.load %[[v9]], %[[v6]] : tensor<1024x!tt.ptr<f32>> |
| 41 | +// CHECK-NEXT: %[[v12:.+]] = tt.load %[[v10]], %[[v6]] : tensor<1024x!tt.ptr<f32>> |
| 42 | +// CHECK-NEXT: %[[v13:.+]] = tt.splat %[[arg2]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 43 | +// CHECK-NEXT: %[[v14:.+]] = tt.addptr %[[v13]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 44 | +// CHECK-NEXT: %[[v15:.+]] = tt.load %[[v14]], %[[v6]] : tensor<1024x!tt.ptr<f32>> |
| 45 | +// CHECK-NEXT: %[[v16:.+]] = arith.addf %[[v12]], %[[v15]] : tensor<1024xf32> |
| 46 | +// CHECK-NEXT: %[[v17:.+]] = tt.splat %[[arg4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 47 | +// CHECK-NEXT: %[[v18:.+]] = tt.splat %[[arg3]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> |
| 48 | +// CHECK-NEXT: %[[v19:.+]] = tt.addptr %[[v17]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 49 | +// CHECK-NEXT: %[[v20:.+]] = tt.addptr %[[v18]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> |
| 50 | +// CHECK-NEXT: tt.store %[[v19]], %[[v11]], %[[v6]] : tensor<1024x!tt.ptr<f32>> |
| 51 | +// CHECK-NEXT: tt.store %[[v20]], %[[v16]], %[[v6]] : tensor<1024x!tt.ptr<f32>> |
| 52 | +// CHECK-NEXT: tt.return |
| 53 | +// CHECK-NEXT: } |
0 commit comments