@@ -26,29 +26,20 @@ func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %i
2626 %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
2727 %pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref <1 xindex >
2828 %3 = tensor.empty (): !output_tensor_type
29- //call @stopTimerMy() : () -> ()
29+
3030 %transpose_in = linalg.transpose ins (%input: !input_tensor_type ) outs (%3: !output_tensor_type ) permutation = [0 , 2 , 1 , 3 ]
3131
32- //call @startTimerMy() : () -> ()
3332 %c0 = arith.constant 0 : index
3433 %c3 = arith.constant 3 : index
3534 %cos_cache_slice = tensor.extract_slice %cos_cache_tensor [0 , 0 , 0 , 0 ] [1 , 1 , 7 , 128 ] [1 , 1 , 1 , 1 ] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
3635 %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0 , 1 ], [2 ],[3 ]] : tensor <1 x1 x7 x128 x!dtype > into tensor <1 x7 x128 x!dtype >
3736 %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0 , 1 ], [2 ]] : tensor <1 x7 x128 x!dtype > into tensor <7 x128 x!dtype >
3837 %pos_ids_index =tensor.expand_shape %pos_ids [[0 ],[1 ,2 ]] output_shape [1 , 7 , 1 ] : tensor <1 x7 xindex > into tensor <1 x7 x1 xindex >
39- //call @stopTimerMy() : () -> ()
40-
41- //call @startTimerMy() : () -> ()
4238
4339 %cos_cache_slice4 = tensor.gather %cos_cache_slice3 [%pos_ids_index ] gather_dims ([0 ]) : (tensor <7 x128 x!dtype >, tensor <1 x7 x1 xindex >) -> tensor <1 x7 x128 x!dtype >
4440
45- //call @stopTimerMy() : () -> ()
46-
47- //call @startTimerMy() : () -> ()
48- // %cos_cache_slice4 = tensor.expand_shape %cos_cache_slice3[[0,1],[2]] output_shape [1,7,128] : tensor<7x128x!dtype> into tensor<1x7x128x!dtype>
4941 %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0 ,1 ],[2 ],[3 ]] output_shape [1 ,1 ,7 ,128 ] : tensor <1 x7 x128 x!dtype > into tensor <1 x1 x7 x128 x!dtype >
5042 %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0 ,1 ,2 ],[3 ]] : tensor <1 x1 x7 x128 x!dtype > into tensor <7 x128 x!dtype >
51- //call @stopTimerMy() : () -> ()
5243
5344 %cos_cache_slice7 = linalg.broadcast ins (%cos_cache_slice6: tensor <7 x128 x!dtype >) outs (%3: !output_tensor_type ) dimensions = [0 , 1 ]
5445 %input_apply_cos_cache = linalg.mul ins (%transpose_in , %cos_cache_slice7: !output_tensor_type , !output_tensor_type ) outs (%3: !output_tensor_type ) -> !output_tensor_type
0 commit comments