Skip to content

Commit 049beb9

Browse files
committed
[LinalgToXeGPU] Remove redundant linalg.broadcasts
Signed-off-by: dchigarev <[email protected]>
1 parent fb51bb4 commit 049beb9

File tree

4 files changed

+315
-0
lines changed

4 files changed

+315
-0
lines changed

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,106 @@ static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
13121312
defOp, "No transpose operation producing the operand was found");
13131313
}
13141314

1315+
// Checks whether the given linalgOp operand is produced by a
1316+
// `linalg::BroadcastOp` that can be replaced by a simple subview
1317+
// (for example broadcast: 7x128 -> 1x1x7x128) and ensures that
1318+
// the broadcast result is only used by linalgOp in question.
1319+
//
1320+
// If a valid `linalg::BroadcastOp` is found, the function removes it
1321+
// and returns the operand of the `linalg::BroadcastOp` as the new
1322+
// linalgOp operand. Otherwise returns the original operand.
1323+
static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
1324+
size_t operandIdx,
1325+
PatternRewriter &rewriter) {
1326+
auto operand = linalgOp.getDpsInputs()[operandIdx];
1327+
auto operandParent = operand;
1328+
1329+
// walk over the 'Value' users and verify that it's only used by 'ops'
1330+
std::function<bool(Value, SmallVector<linalg::LinalgOp> &)> onlyUsedByOp;
1331+
onlyUsedByOp = [&onlyUsedByOp](Value value,
1332+
SmallVector<linalg::LinalgOp> &ops) -> bool {
1333+
bool result = true;
1334+
for (auto user : value.getUsers()) {
1335+
if (auto linalgOpUser = dyn_cast<linalg::LinalgOp>(user))
1336+
result &= std::find(ops.begin(), ops.end(), linalgOpUser) !=
1337+
ops.end(); // linalgOpUser == op;
1338+
else if (isa<memref::DeallocOp>(user))
1339+
continue; // allow deallocs as users
1340+
else if (auto subview = dyn_cast<memref::SubViewOp>(user))
1341+
result &= onlyUsedByOp(subview.getResult(), ops);
1342+
else
1343+
return false;
1344+
}
1345+
return result;
1346+
};
1347+
1348+
linalg::BroadcastOp broadcastOp = nullptr;
1349+
while (auto defOp = operandParent.getDefiningOp()) {
1350+
for (auto x : defOp->getUsers()) {
1351+
if (!isa<linalg::BroadcastOp>(x))
1352+
continue;
1353+
1354+
if (broadcastOp) {
1355+
rewriter.notifyMatchFailure(broadcastOp,
1356+
"Only one broadcast operation is allowed");
1357+
return operand;
1358+
}
1359+
1360+
broadcastOp = dyn_cast<linalg::BroadcastOp>(x);
1361+
auto broadcastRes = broadcastOp.getDpsInits()[0];
1362+
SmallVector<linalg::LinalgOp> ops({linalgOp, broadcastOp});
1363+
1364+
// verify that there are no other users of the broadcast result
1365+
// other than the linalgOp in question
1366+
if (!onlyUsedByOp(broadcastRes, ops)) {
1367+
rewriter.notifyMatchFailure(
1368+
broadcastOp, "Broadcast result is used by more than one operation");
1369+
return operand;
1370+
}
1371+
break;
1372+
}
1373+
1374+
if (defOp->getOperands().size() == 0)
1375+
break;
1376+
1377+
operandParent = defOp->getOperand(0);
1378+
}
1379+
if (!broadcastOp) {
1380+
rewriter.notifyMatchFailure(
1381+
linalgOp, "No broadcast operation producing the operand was found");
1382+
return operand;
1383+
}
1384+
1385+
auto brInp = broadcastOp.getDpsInputs()[0];
1386+
auto brOut = broadcastOp.getDpsInits()[0];
1387+
1388+
auto inpType = dyn_cast<MemRefType>(brInp.getType());
1389+
auto outType = dyn_cast<MemRefType>(brOut.getType());
1390+
if (!inpType || !outType)
1391+
return operand;
1392+
1393+
auto inpShape = inpType.getShape();
1394+
auto outShape = outType.getShape();
1395+
1396+
if (inpShape.size() < 2) {
1397+
rewriter.notifyMatchFailure(broadcastOp, "Only nD broadcast is supported");
1398+
return operand;
1399+
}
1400+
1401+
if (!utils::canSqueezeDims(inpShape) || !utils::canSqueezeDims(outShape)) {
1402+
rewriter.notifyMatchFailure(broadcastOp,
1403+
"Can't squeeze broadcast operands to 2D");
1404+
return operand;
1405+
}
1406+
1407+
auto res = utils::reduceMemrefDims(rewriter, broadcastOp.getLoc(), brInp);
1408+
if (failed(res))
1409+
return operand;
1410+
1411+
rewriter.eraseOp(broadcastOp);
1412+
return res.value();
1413+
}
1414+
13151415
// Create XeGPU DPAS kernel out of GEMM-like operation.
13161416
static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
13171417
ArrayRef<int64_t> dpasTile, int kTile,
@@ -1690,7 +1790,9 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
16901790

16911791
// Create descriptors and load values for all inputs.
16921792
SmallVector<SmallVector<Value>> loadedInputs;
1793+
size_t operandIdx = 0;
16931794
for (auto input : linalgOp.getDpsInputs()) {
1795+
input = findAndReplaceBroadcast(linalgOp, operandIdx, rewriter);
16941796
SmallVector<Value> inputTiles =
16951797
createDescriptorTiles(rewriter, loc, input, outputShape, tileShape);
16961798

@@ -1699,6 +1801,7 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
16991801
/*vnniConf=*/std::nullopt,
17001802
/*transpose=*/nullptr, /*transpose_bit=*/nullptr);
17011803
loadedInputs.push_back(loadedVals);
1804+
operandIdx++;
17021805
}
17031806

17041807
// Extract SIMD sized sub-tiles from loaded tiles.

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ void populateGPUPipeline(OpPassManager &pm,
3737

3838
pm.addPass(createDecomposeTensorOperation());
3939
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
40+
pm.addPass(createCanonicalizerPass());
4041

4142
pm.addPass(bufferization::createEmptyTensorEliminationPass());
4243
pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @broadcast_eliminate_2d
4+
func.func @broadcast_eliminate_2d() {
5+
%c1 = arith.constant 1 : index
6+
%c2 = arith.constant 2 : index
7+
%c4 = arith.constant 4 : index
8+
9+
// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
10+
%0 = memref.alloc() : memref<7x128xf16>
11+
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
12+
%2 = memref.alloc() : memref<1x1x7x128xf16>
13+
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
14+
%3 = memref.alloc() : memref<1x1x7x128xf16>
15+
16+
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
17+
// CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
18+
%slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
19+
%1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>
20+
21+
// CHECK-NOT: linalg.broadcast
22+
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0, 1]
23+
// CHECK: xegpu.create_nd_tdesc %[[MEMREF_0]]
24+
linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
25+
gpu.terminator
26+
}
27+
return
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: func.func @broadcast_eliminate
33+
func.func @broadcast_eliminate_3d() {
34+
%c1 = arith.constant 1 : index
35+
%c2 = arith.constant 2 : index
36+
%c4 = arith.constant 4 : index
37+
38+
// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<1x7x128xf16>
39+
%0 = memref.alloc() : memref<1x7x128xf16>
40+
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
41+
%2 = memref.alloc() : memref<1x1x7x128xf16>
42+
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
43+
%3 = memref.alloc() : memref<1x1x7x128xf16>
44+
45+
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
46+
// CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
47+
%slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
48+
%1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>
49+
50+
// CHECK-NOT: linalg.broadcast
51+
linalg.broadcast ins(%0 : memref<1x7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0]
52+
// Squeezing the %0 before passing to 'linalg.add'
53+
// CHECK: %[[MEMREF0_SQUEEZ:.+]] = memref.subview %[[MEMREF_0]][0, 0, 0] [1, 7, 128] [1, 1, 1] :
54+
// CHECK-SAME: memref<1x7x128xf16> to memref<7x128xf16, strided<[128, 1]>>
55+
// CHECK: xegpu.create_nd_tdesc %[[MEMREF0_SQUEEZ]]
56+
linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
57+
gpu.terminator
58+
}
59+
return
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: func.func @complex_broadcast
65+
func.func @complex_broadcast_3d() {
66+
%c1 = arith.constant 1 : index
67+
%c2 = arith.constant 2 : index
68+
%c4 = arith.constant 4 : index
69+
70+
// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
71+
%0 = memref.alloc() : memref<7x128xf16>
72+
// CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<7x7x128xf16>
73+
%1 = memref.alloc() : memref<7x7x128xf16>
74+
// CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<7x7x128xf16>
75+
%2 = memref.alloc() : memref<7x7x128xf16>
76+
// CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<7x7x128xf16>
77+
%3 = memref.alloc() : memref<7x7x128xf16>
78+
79+
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
80+
// CHECK: linalg.broadcast
81+
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<7x7x128xf16>) dimensions = [0]
82+
linalg.add ins(%1, %2 : memref<7x7x128xf16>, memref<7x7x128xf16>) outs(%3 : memref<7x7x128xf16>)
83+
gpu.terminator
84+
}
85+
return
86+
}
87+
88+
// -----
89+
90+
// CHECK-LABEL: func.func @single_broadcast
91+
func.func @single_broadcast() {
92+
%c1 = arith.constant 1 : index
93+
%c2 = arith.constant 2 : index
94+
%c4 = arith.constant 4 : index
95+
96+
// CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
97+
%0 = memref.alloc() : memref<7x128xf16>
98+
// CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<1x1x7x128xf16>
99+
%1 = memref.alloc() : memref<1x1x7x128xf16>
100+
101+
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
102+
// CHECK: linalg.broadcast
103+
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16>) dimensions = [0, 1]
104+
gpu.terminator
105+
}
106+
return
107+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s
2+
3+
!dtype=f16
4+
!input_memref_type=memref<2x7x32x128x!dtype>
5+
!input_tensor_type=tensor<2x7x32x128x!dtype>
6+
!output_memref_type=memref<2x32x7x128x!dtype>
7+
!output_tensor_type=tensor<2x32x7x128x!dtype>
8+
!cos_sin_cache_memref_type=memref<1x1x2048x128x!dtype>
9+
!cos_sin_cache_tensor_type=tensor<1x1x2048x128x!dtype>
10+
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
11+
!pos_ids_memref_type=memref<1x7xindex>
12+
!pos_ids_tensor_type=tensor<1x7xindex>
13+
#map = affine_map<(xi, yi, zi) -> ((xi * 3 * 4 + yi * 4 + zi) * 2)>
14+
module @fragment_name {
15+
memref.global "private" constant @_cos_cache : !cos_sin_cache_memref_type = dense<3.000000e+00>
16+
memref.global "private" constant @_sin_cache : !cos_sin_cache_memref_type = dense<2.000000e+00>
17+
memref.global "private" constant @_iinput_const : !input_memref_type = dense<3.000000e+00>
18+
memref.global "private" constant @_ipos_ids_const : !pos_ids_memref_type = dense<1>
19+
memref.global "private" constant @_ipos_id_end_const : memref<1xindex> = dense<1>
20+
func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %ipos_id_end: memref<1xindex>, %out: !output_memref_type) {
21+
%input = bufferization.to_tensor %iinput restrict : !input_memref_type
22+
%cos_cache = memref.get_global @_cos_cache : !cos_sin_cache_memref_type
23+
%sin_cache = memref.get_global @_sin_cache : !cos_sin_cache_memref_type
24+
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
25+
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
26+
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
27+
%pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
28+
%3 = tensor.empty(): !output_tensor_type
29+
//call @stopTimerMy() : () -> ()
30+
%transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]
31+
32+
//call @startTimerMy() : () -> ()
33+
%c0 = arith.constant 0 : index
34+
%c3 = arith.constant 3 : index
35+
%cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
36+
%cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
37+
%cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
38+
%pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>
39+
//call @stopTimerMy() : () -> ()
40+
41+
//call @startTimerMy() : () -> ()
42+
43+
%cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
44+
45+
//call @stopTimerMy() : () -> ()
46+
47+
//call @startTimerMy() : () -> ()
48+
// %cos_cache_slice4 = tensor.expand_shape %cos_cache_slice3[[0,1],[2]] output_shape [1,7,128] : tensor<7x128x!dtype> into tensor<1x7x128x!dtype>
49+
%cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
50+
%cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
51+
//call @stopTimerMy() : () -> ()
52+
53+
%cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
54+
%input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
55+
56+
%head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
57+
%c2 = arith.constant 2 : index
58+
%half_head_dim = arith.floordivsi %head_dim, %c2 : index
59+
%transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
60+
%transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
61+
%cnegative1 = arith.constant dense<-1.000000e+00> : tensor<2x32x7x64x!dtype>
62+
%empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
63+
%transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>
64+
65+
%transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type
66+
67+
%sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
68+
%sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
69+
%sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
70+
%sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
71+
72+
%sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
73+
%sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
74+
75+
%sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
76+
%input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
77+
78+
%result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
79+
bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
80+
return
81+
}
82+
83+
func.func @main() {
84+
%inp = memref.get_global @_iinput_const : !input_memref_type
85+
%ipos_ids = memref.get_global @_ipos_ids_const : !pos_ids_memref_type
86+
%ipos_id_end = memref.get_global @_ipos_id_end_const : memref<1xindex>
87+
88+
%out = memref.alloc() {alignment = 64 : i64} : !output_memref_type
89+
90+
func.call @RoPE(%inp, %ipos_ids, %ipos_id_end, %out) : (!input_memref_type, !pos_ids_memref_type, memref<1xindex>, !output_memref_type) -> ()
91+
92+
%out_subview = memref.subview %out[0, 0, 0, 0] [2, 1, 1, 1] [1, 1, 1, 1] : !output_memref_type to memref<2xf16, strided<[28672]>>
93+
%cast = memref.cast %out_subview : memref<2xf16, strided<[28672]>> to memref<*xf16>
94+
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
95+
96+
return
97+
}
98+
99+
func.func private @printMemrefF16(%ptr : memref<*xf16>)
100+
}
101+
102+
// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
103+
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [28672] data =
104+
// CHECK-NEXT: [3, 3]

0 commit comments

Comments
 (0)