102 changes: 102 additions & 0 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -1312,6 +1312,105 @@ static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
defOp, "No transpose operation producing the operand was found");
}

// Checks whether the given linalgOp operand is produced by a
// `linalg::BroadcastOp` that can be replaced by a simple subview
// (for example broadcast: 7x128 -> 1x1x7x128) and ensures that
// the broadcast result is only used by linalgOp in question.
//
// If a valid `linalg::BroadcastOp` is found, the function removes it
// and returns the operand of the `linalg::BroadcastOp` as the new
// linalgOp operand. Otherwise returns the original operand.
static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
size_t operandIdx,
PatternRewriter &rewriter) {
@dchigarev (Contributor, Author) commented on lines +1315 to +1325 (Dec 12, 2024):
You may ask: why can't we just lower every such linalg.broadcast into something like memref.expand_shape via a separate pattern, instead of doing this "find a broadcast that produces an operand of a linalg-op that we're already lowering to xegpu" quest?

The problem is that the memref-to-spirv pass supports only a very limited set of memref ops. It's basically only memref.subview that is supported, and we can't expand memref shapes with it. So we can't simply replace linalg.broadcast with memref.expand_shape, since our pipeline would then fail:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU

// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

// ElementWiseToXeGPUPattern:
%out_squeeze = memref.subview %out : memref<1x7x128> to memref<7x128>
%desc = xegpu.create_tensor_desc %out_squeeze 
...

And although a human eye can see that the memref.expand_shape + memref.subview pair could be eliminated here, none of the upstream passes can do that. Even if such an expand_shape/subview merging pass existed, we still could not guarantee that the memref.expand_shape is always followed by a rank-reducing memref.subview it can be merged with. Example:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.trickyOp ins(%out, ...)

// --------- after LinalgToXeGPU

// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

// 'linalg.trickyOp' is not supported by LinalgToXeGPU pass
// no rank-reducing memref.subview to merge 'expand_shape' with
linalg.trickyOp ins(%out, ...)
...

// --------- after LinalgToLoops
// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

for {
  for {
    for {
      %outScalar = memref.load %out
      arith.trickyOp %outScalar
      ...
    }
  }
}
...

So the only option we're left with is to "lower" linalg.broadcast only when it produces an operand of a linalgOp that we're currently lowering to xegpu, and to do so by simply erasing the broadcastOp and forwarding its input as the operand of the linalgOp in question. Example:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU
// ElementWiseToXeGPUPattern:
%inp = memref.alloc() : memref<7x128xf16>
%desc = xegpu.create_tensor_desc %inp
...

auto operand = linalgOp.getDpsInputs()[operandIdx];
auto operandParent = operand;

// walk over the 'Value' users and verify that it's only used by 'ops'
std::function<bool(Value, SmallVector<linalg::LinalgOp> &)> onlyUsedByOp =
[&onlyUsedByOp](Value value, SmallVector<linalg::LinalgOp> &ops) -> bool {
bool result = true;
for (auto user : value.getUsers()) {
if (auto linalgOpUser = dyn_cast<linalg::LinalgOp>(user))
result &= std::find(ops.begin(), ops.end(), linalgOpUser) !=
ops.end(); // linalgOpUser == op;
else if (isa<memref::DeallocOp>(user))
continue; // allow deallocs as users
else if (auto subview = dyn_cast<memref::SubViewOp>(user))
result &= onlyUsedByOp(subview.getResult(), ops);
else
return false;
}
return result;
};

linalg::BroadcastOp broadcastOp = nullptr;
while (auto defOp = operandParent.getDefiningOp()) {
for (auto x : defOp->getUsers()) {
if (!isa<linalg::BroadcastOp>(x))
continue;

if (broadcastOp) {
rewriter.notifyMatchFailure(broadcastOp,
"Only one broadcast operation is allowed");
return operand;
}

broadcastOp = dyn_cast<linalg::BroadcastOp>(x);
auto broadcastRes = broadcastOp.getDpsInits()[0];
SmallVector<linalg::LinalgOp> ops({linalgOp, broadcastOp});

// verify that there are no other users of the broadcast result
// other than the linalgOp in question
if (!onlyUsedByOp(broadcastRes, ops)) {
rewriter.notifyMatchFailure(
broadcastOp, "Broadcast result is used by more than one operation");
return operand;
}
break;
}

if (defOp->getOperands().size() == 0)
break;

operandParent = defOp->getOperand(0);
}
if (!broadcastOp) {
rewriter.notifyMatchFailure(
linalgOp, "No broadcast operation producing the operand was found");
return operand;
}

auto brInp = broadcastOp.getDpsInputs()[0];
auto brOut = broadcastOp.getDpsInits()[0];

auto inpType = dyn_cast<MemRefType>(brInp.getType());
auto outType = dyn_cast<MemRefType>(brOut.getType());
if (!inpType || !outType)
return operand;

auto inpShape = inpType.getShape();
auto outShape = outType.getShape();

if (inpShape.size() < 2) {
rewriter.notifyMatchFailure(broadcastOp, "Only nD broadcast is supported");
return operand;
}

if (!utils::canSqueezeDims(inpShape) || !utils::canSqueezeDims(outShape)) {
rewriter.notifyMatchFailure(broadcastOp,
"Can't squeeze broadcast operands to 2D");
return operand;
}

auto res = utils::reduceMemrefDims(rewriter, broadcastOp.getLoc(), brInp);
if (failed(res))
return operand;

rewriter.eraseOp(broadcastOp);
return res.value();
}

// Create XeGPU DPAS kernel out of GEMM-like operation.
static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
ArrayRef<int64_t> dpasTile, int kTile,
@@ -1690,7 +1789,9 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,

// Create descriptors and load values for all inputs.
SmallVector<SmallVector<Value>> loadedInputs;
size_t operandIdx = 0;
for (auto input : linalgOp.getDpsInputs()) {
input = findAndReplaceBroadcast(linalgOp, operandIdx, rewriter);
SmallVector<Value> inputTiles =
createDescriptorTiles(rewriter, loc, input, outputShape, tileShape);

@@ -1699,6 +1800,7 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
/*vnniConf=*/std::nullopt,
/*transpose=*/nullptr, /*transpose_bit=*/nullptr);
loadedInputs.push_back(loadedVals);
operandIdx++;
}

// Extract SIMD sized sub-tiles from loaded tiles.
1 change: 1 addition & 0 deletions lib/gc/Transforms/GPU/Pipeline.cpp
@@ -37,6 +37,7 @@ void populateGPUPipeline(OpPassManager &pm,

pm.addPass(createDecomposeTensorOperation());
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
pm.addPass(createCanonicalizerPass());
@dchigarev (Contributor, Author) commented (Dec 12, 2024):
We should do the 'cleaning' right after the tiling. Otherwise the bufferization pass may produce memref.cast ops that cannot be lowered by memref-to-spirv.
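To give a rough idea of the failure mode (a hypothetical sketch, not taken from actual pipeline output): bufferizing the un-canonicalized tiled IR can end up casting a statically-shaped alloc to a dynamically-shaped memref, and memref-to-spirv has no pattern for such a cast:

%alloc = memref.alloc() : memref<8x16xf16>
// static-to-dynamic cast: no memref-to-spirv lowering available for this
%cast = memref.cast %alloc : memref<8x16xf16> to memref<?x?xf16>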


pm.addPass(bufferization::createEmptyTensorEliminationPass());
pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
109 changes: 109 additions & 0 deletions test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-broadcast-fold.mlir
@@ -0,0 +1,109 @@
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @broadcast_eliminate_2d
func.func @broadcast_eliminate_2d() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
%0 = memref.alloc() : memref<7x128xf16>
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
%2 = memref.alloc() : memref<1x1x7x128xf16>
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
%3 = memref.alloc() : memref<1x1x7x128xf16>

gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
// CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
%slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
%1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>

// CHECK-NOT: linalg.broadcast
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0, 1]
// CHECK: xegpu.create_nd_tdesc %[[MEMREF_0]]
linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
gpu.terminator
}
return
}

// -----

// CHECK-LABEL: func.func @broadcast_eliminate
func.func @broadcast_eliminate_3d() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<1x7x128xf16>
%0 = memref.alloc() : memref<1x7x128xf16>
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
%2 = memref.alloc() : memref<1x1x7x128xf16>
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
%3 = memref.alloc() : memref<1x1x7x128xf16>

gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
// CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
%slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
%1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>

// CHECK-NOT: linalg.broadcast
linalg.broadcast ins(%0 : memref<1x7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0]
// Squeeze %0 before passing it to 'linalg.add'
// CHECK: %[[MEMREF0_SQUEEZ:.+]] = memref.subview %[[MEMREF_0]][0, 0, 0] [1, 7, 128] [1, 1, 1] :
// CHECK-SAME: memref<1x7x128xf16> to memref<7x128xf16, strided<[128, 1]>>
// CHECK: xegpu.create_nd_tdesc %[[MEMREF0_SQUEEZ]]
linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
gpu.terminator
}
return
}

// -----

// CHECK-LABEL: func.func @complex_broadcast
func.func @complex_broadcast_3d() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
%0 = memref.alloc() : memref<7x128xf16>
// CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<7x7x128xf16>
%1 = memref.alloc() : memref<7x7x128xf16>
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<7x7x128xf16>
%2 = memref.alloc() : memref<7x7x128xf16>
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<7x7x128xf16>
%3 = memref.alloc() : memref<7x7x128xf16>

gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
// This broadcast can't be replaced by a single memref.subview, so we can't remove it
// CHECK: linalg.broadcast
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<7x7x128xf16>) dimensions = [0]
linalg.add ins(%1, %2 : memref<7x7x128xf16>, memref<7x7x128xf16>) outs(%3 : memref<7x7x128xf16>)
gpu.terminator
}
return
}

// -----

// CHECK-LABEL: func.func @single_broadcast
func.func @single_broadcast() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
%0 = memref.alloc() : memref<7x128xf16>
// CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<1x1x7x128xf16>
%1 = memref.alloc() : memref<1x1x7x128xf16>

gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
// The broadcast result is not an input of any xegpu operation, so we can't lower it
// CHECK: linalg.broadcast
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16>) dimensions = [0, 1]
gpu.terminator
}
return
}
95 changes: 95 additions & 0 deletions test/mlir/test/gc/gpu-runner/XeGPU/rope_sanity_test.mlir
@@ -0,0 +1,95 @@
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s

!dtype=f16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
!pos_ids_tensor_type=tensor<1x7xindex>
#map = affine_map<(xi, yi, zi) -> ((xi * 3 * 4 + yi * 4 + zi) * 2)>
module @fragment_name {
memref.global "private" constant @_cos_cache : !cos_sin_cache_memref_type = dense<3.000000e+00>
memref.global "private" constant @_sin_cache : !cos_sin_cache_memref_type = dense<2.000000e+00>
memref.global "private" constant @_iinput_const : !input_memref_type = dense<3.000000e+00>
memref.global "private" constant @_ipos_ids_const : !pos_ids_memref_type = dense<1>
memref.global "private" constant @_ipos_id_end_const : memref<1xindex> = dense<1>
func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %ipos_id_end: memref<1xindex>, %out: !output_memref_type) {
%input = bufferization.to_tensor %iinput restrict : !input_memref_type
%cos_cache = memref.get_global @_cos_cache : !cos_sin_cache_memref_type
%sin_cache = memref.get_global @_sin_cache : !cos_sin_cache_memref_type
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
%pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
%3 = tensor.empty(): !output_tensor_type

%transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]

%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>

%cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

%cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

%cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

%head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
%c2 = arith.constant 2 : index
%half_head_dim = arith.floordivsi %head_dim, %c2 : index
%transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%cnegative1 = arith.constant dense<-1.000000e+00> : tensor<2x32x7x64x!dtype>
%empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
%transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>

%transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type

%sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

%sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

%sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

%result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
return
}

func.func @main() {
%inp = memref.get_global @_iinput_const : !input_memref_type
%ipos_ids = memref.get_global @_ipos_ids_const : !pos_ids_memref_type
%ipos_id_end = memref.get_global @_ipos_id_end_const : memref<1xindex>

%out = memref.alloc() {alignment = 64 : i64} : !output_memref_type

func.call @RoPE(%inp, %ipos_ids, %ipos_id_end, %out) : (!input_memref_type, !pos_ids_memref_type, memref<1xindex>, !output_memref_type) -> ()

%out_subview = memref.subview %out[0, 0, 0, 0] [2, 1, 1, 1] [1, 1, 1, 1] : !output_memref_type to memref<2xf16, strided<[28672]>>
%cast = memref.cast %out_subview : memref<2xf16, strided<[28672]>> to memref<*xf16>
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()

return
}

func.func private @printMemrefF16(%ptr : memref<*xf16>)
}

// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [28672] data =
// CHECK-NEXT: [3, 3]