// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s
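// Applies Rotary Position Embedding (RoPE) to a constant f16 input of shape
// 2x7x32x128 (presumably batch x seq-len x heads x head-dim) and checks two
// elements of the transposed 2x32x7x128 output against the expected value.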

!dtype=f16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
!pos_ids_tensor_type=tensor<1x7xindex>
#map = affine_map<(xi, yi, zi) -> ((xi * 3 * 4 + yi * 4 + zi) * 2)>
module @fragment_name {
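// Constant test data: the cos cache is all 3.0, the sin cache is all 2.0, the
// input is all 3.0, and every position id is 1.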
memref.global "private" constant @_cos_cache : !cos_sin_cache_memref_type = dense<3.000000e+00>
memref.global "private" constant @_sin_cache : !cos_sin_cache_memref_type = dense<2.000000e+00>
memref.global "private" constant @_iinput_const : !input_memref_type = dense<3.000000e+00>
memref.global "private" constant @_ipos_ids_const : !pos_ids_memref_type = dense<1>
memref.global "private" constant @_ipos_id_end_const : memref<1xindex> = dense<1>
func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %ipos_id_end: memref<1xindex>, %out: !output_memref_type) {
  %input = bufferization.to_tensor %iinput restrict : !input_memref_type
  %cos_cache = memref.get_global @_cos_cache : !cos_sin_cache_memref_type
  %sin_cache = memref.get_global @_sin_cache : !cos_sin_cache_memref_type
  %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
  %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
  %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
  %pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
  %3 = tensor.empty() : !output_tensor_type

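  // Transpose 2x7x32x128 -> 2x32x7x128, moving the head dimension ahead of
  // the sequence dimension.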
  %transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3: !output_tensor_type) permutation = [0, 2, 1, 3]

  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
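  // Collapse the first seven rows of the cos cache to 7x128, then gather the
  // rows selected by the position ids.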
  %cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
  %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
  %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
  %pos_ids_index = tensor.expand_shape %pos_ids [[0], [1, 2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>

  %cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

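  // Reshape the gathered 1x7x128 slice back to 7x128 ahead of the broadcast.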
  %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
  %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

  %cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
  %input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

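  // Rotate half: negate the second half of the head dimension and concatenate
  // it in front of the first half (the standard RoPE rotation).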
  %head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
  %c2 = arith.constant 2 : index
  %half_head_dim = arith.floordivsi %head_dim, %c2 : index
  %transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
  %transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
  %cnegative1 = arith.constant dense<-1.000000e+00> : tensor<2x32x7x64x!dtype>
  %empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
  %transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>

  %transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type

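  // Same slice/gather/broadcast sequence for the sin cache, applied to the
  // rotated input.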
  %sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
  %sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
  %sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
  %sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

  %sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
  %sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

  %sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
  %input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

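  // result = x * cos + rotate_half(x) * sin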
  %result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
  bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
  return
}

func.func @main() {
  %inp = memref.get_global @_iinput_const : !input_memref_type
  %ipos_ids = memref.get_global @_ipos_ids_const : !pos_ids_memref_type
  %ipos_id_end = memref.get_global @_ipos_id_end_const : memref<1xindex>

  %out = memref.alloc() {alignment = 64 : i64} : !output_memref_type

  func.call @RoPE(%inp, %ipos_ids, %ipos_id_end, %out) : (!input_memref_type, !pos_ids_memref_type, memref<1xindex>, !output_memref_type) -> ()

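  // Print out[0, 0, 0, 0] and out[1, 0, 0, 0]. Both lie in the first half of
  // the head dimension, so each equals 3.0 * 3.0 + (-3.0) * 2.0 = 3.0.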
  %out_subview = memref.subview %out[0, 0, 0, 0] [2, 1, 1, 1] [1, 1, 1, 1] : !output_memref_type to memref<2xf16, strided<[28672]>>
  %cast = memref.cast %out_subview : memref<2xf16, strided<[28672]>> to memref<*xf16>
  call @printMemrefF16(%cast) : (memref<*xf16>) -> ()

  return
}

func.func private @printMemrefF16(%ptr : memref<*xf16>)
}

// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [28672] data =
// CHECK-NEXT: [3, 3]