|
13 | 13 | #map3 = affine_map<(d0, d1) -> (d0, d1)> |
14 | 14 | module @gemm { |
15 | 15 | func.func @main() { |
16 | | - %0= arith.constant dense<[[1, 2, 3], [1, 1, 1], [1, 1, 1]]>:tensor<3x3xi8> |
17 | | - %1 = arith.constant dense<[[1, 1, 1], [1, 2, 3], [1, 1, 1]]>:tensor<3x3xi8> |
| 16 | + %0= arith.constant dense<[[1, 1, 1], [1, 1, 2], [3, 3, 3]]>:tensor<3x3xi8> |
| 17 | + %1 = arith.constant dense<[[10, 11, 12], [13, 14, 15], [16, 17, 18]]>:tensor<3x3xi8> |
18 | 18 | %2= arith.constant dense<[[1, 1, 1], [1, 1, 1], [1, 2, 3]]>:tensor<3x3xi8> |
19 | 19 | %3 = call @test(%0,%1,%2) : (tensor<3x3xi8>,tensor<3x3xi8>,tensor<3x3xi8>) -> tensor<3x3xi8> |
20 | | - %unranked = tensor.cast %3 : tensor<3x3xi8>to tensor<*xi8> |
21 | | - call @printMemrefI8(%unranked) : (tensor<*xi8>) -> () |
22 | | - // CHECK: |
| 20 | + %4 = call @castI8toI32(%3): (tensor<3x3xi8>) -> tensor<3x3xi32> |
| 21 | + %unranked = tensor.cast %4 : tensor<3x3xi32>to tensor<*xi32> |
| 22 | + call @printMemrefI32(%unranked) : (tensor<*xi32>) -> () |
| 23 | + // CHECK: Unranked Memref base@ = {{(0x)?[0-9a-f]*}} |
| 24 | + // CHECK-NEXT: [40, 43, 46] |
| 25 | + // CHECK-NEXT: [56, 60, 64] |
| 26 | + // CHECK-NEXT: [118, 128, 138] |
23 | 27 | return |
24 | 28 | } |
25 | | -func.func private @printMemrefI8(tensor<*xi8>) |
| 29 | + |
| 30 | +func.func @castI8toI32(%arg0: tensor<3x3xi8>) -> tensor<3x3xi32> { |
| 31 | + %1 = tensor.empty() : tensor<3x3xi32> |
| 32 | + %2 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} |
| 33 | + ins(%arg0: tensor<3x3xi8>) |
| 34 | + outs(%1 : tensor<3x3xi32>) |
| 35 | + attrs = {iterator_ranges = [3, 3]} { |
| 36 | + ^bb0(%arg1: i8, %arg2: i32): |
| 37 | + %3 = arith.extui %arg1: i8 to i32 |
| 38 | + linalg.yield %3 : i32 |
| 39 | + } -> tensor<3x3xi32> |
| 40 | + return %2: tensor<3x3xi32> |
| 41 | +} |
| 42 | + |
| 43 | +func.func private @printMemrefI32(tensor<*xi32>) attributes { llvm.emit_c_interface } |
26 | 44 | func.func @test(%arg0: tensor<3x3xi8>, %arg1: tensor<3x3xi8>, %arg2: tensor<3x3xi8>) -> tensor<3x3xi8> { |
27 | 45 | %c0_i8 = arith.constant 0 : i8 |
28 | 46 | %0 = tensor.empty() : tensor<3x3xi8> |
|
0 commit comments