Skip to content

Commit 739bb74

Browse files
inho9606 authored and memfrob committed
Added static verification for Linalg Ops.
This verification checks whether the indices for static-shaped operands on Linalg ops access out-of-bounds memory. For dynamic-shaped operands, we would be able to check it at the runtime stage. Found several invalid Linalg op test cases, and fixed them. Reviewed By: hanchung. Differential Revision: https://reviews.llvm.org/D98390
1 parent c9f9d72 commit 739bb74

File tree

9 files changed

+113
-44
lines changed

9 files changed

+113
-44
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,5 +433,54 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
433433
++idx;
434434
}
435435

436+
// Check if given shapes match to inferred shapes.
437+
Optional<SmallVector<int64_t, 4>> loopRanges = linalgOp.getStaticLoopRanges();
438+
if (!loopRanges)
439+
return linalgOp.emitError("unable to find loop range for operation");
440+
441+
// Verify only static cases since we can't get exact dimension sizes and loop
442+
// ranges for dynamic cases in this stage.
443+
if (llvm::none_of(*loopRanges, [](int64_t &range) {
444+
return range == ShapedType::kDynamicSize;
445+
})) {
446+
for (int64_t &range : *loopRanges)
447+
range -= 1;
448+
for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) {
449+
auto indices = indexingMaps[en.index()].compose(*loopRanges);
450+
for (auto j : llvm::seq<unsigned>(0, en.value().getRank())) {
451+
452+
// Ignore dynamic dimension or the case that the inferred last index is
453+
// zero. The index is increasing or decreasing in Linalg, for example,
454+
// the last index should be `0` or `size-1`. We only check the cases
455+
// that are non-zero because most of cases are increasing and it is too
456+
// expensive to find the shape of decreasing cases.
457+
if (en.value().isDynamicDim(j) || indices[j] == 0)
458+
continue;
459+
460+
// The size of shaped operands and inferred dimension size should be
461+
// same. But, for now we check if the inferred sizes are in boundary of
462+
// shaped operands' size or not in case that Affine Expressions are
463+
// complicated such as d0 * 3 + d1 since it is not easy to handle the
464+
// issues.
465+
auto inferredSize = indices[j] + 1;
466+
auto shapedDimSize = en.value().getDimSize(j);
467+
if (indexingMaps[en.index()].getResult(j).dyn_cast<AffineDimExpr>()) {
468+
if (inferredSize != shapedDimSize) {
469+
return linalgOp.emitOpError("inferred shaped operand #")
470+
<< en.index() << " has shape's dimension #" << j << " to be "
471+
<< inferredSize << ", but found " << shapedDimSize;
472+
}
473+
} else {
474+
if (inferredSize > shapedDimSize) {
475+
return linalgOp.emitOpError("inferred shaped operand #")
476+
<< en.index() << " has shape's dimension #" << j
477+
<< " to be greater than or equal to " << inferredSize
478+
<< ", but found " << shapedDimSize;
479+
}
480+
}
481+
}
482+
}
483+
}
484+
436485
return success();
437486
}

mlir/test/Dialect/Linalg/fusion-2-level.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
2828
scf.for %arg10 = %c0 to %10 step %c4 {
2929
%14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
3030
%16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
31-
%17 = memref.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
31+
%17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
3232
linalg.matmul ins(%14, %16: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
3333
outs(%17: memref<?x?xf32, offset: ?, strides: [?, ?]>)
3434
}

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
22

3-
func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
4-
linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
3+
func @generalize_conv(%input : memref<1x449x562x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
4+
linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x449x562x3xf32>, memref<1x112x112x32xf32>
55
return
66
}
77

@@ -10,7 +10,7 @@ func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32
1010
// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
1111

1212
// CHECK: func @generalize_conv
13-
// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32>
13+
// CHECK-SAME: %[[INPUT:.+]]: memref<1x449x562x3xf32>
1414
// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
1515
// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
1616

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,23 @@ func @illegal_fill_tensor_with_memref_return
703703
%0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
704704
return %0 : memref<?x?xf32>
705705
}
706+
707+
// -----
708+
709+
func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
710+
// expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}}
711+
linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
712+
outs(%arg2 :memref<2x4xf32>)
713+
return
714+
}
715+
716+
// -----
717+
718+
func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
719+
// expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
720+
linalg.conv_2d_input_nhwc_filter_hwcf
721+
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
722+
ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>)
723+
outs(%output : memref<1x2x3x1xf32>)
724+
return
725+
}

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,15 @@ func @conv_3d_input_ncdhw_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: m
282282
// CHECK: %{{.+}} = linalg.pooling_nhwc_sum
283283
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
284284
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
285-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
285+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
286286
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
287-
func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
287+
func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
288288
%fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
289289
%init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
290290
%cst = constant 0.000000e+00 : f32
291291
%fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
292292
%res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
293-
ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
293+
ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
294294
outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
295295
return %res : tensor<1x2x2x1xf32>
296296
}
@@ -301,11 +301,11 @@ func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
301301
// CHECK: linalg.pooling_nhwc_sum
302302
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
303303
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
304-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
304+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
305305
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
306-
func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
306+
func @pooling_nhwc_sum(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
307307
linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
308-
ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
308+
ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
309309
outs(%output: memref<1x2x2x1xf32>)
310310
return
311311
}
@@ -316,15 +316,15 @@ func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %out
316316
// CHECK: %{{.+}} = linalg.pooling_nhwc_max
317317
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
318318
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
319-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
319+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
320320
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
321-
func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
321+
func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
322322
%fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
323323
%init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
324324
%cst = constant 0.000000e+00 : f32
325325
%fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
326326
%res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
327-
ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
327+
ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
328328
outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
329329
return %res : tensor<1x2x2x1xf32>
330330
}
@@ -335,11 +335,11 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
335335
// CHECK: linalg.pooling_nhwc_max
336336
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
337337
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
338-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
338+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
339339
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
340-
func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
340+
func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
341341
linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
342-
ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
342+
ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
343343
outs(%output: memref<1x2x2x1xf32>)
344344
return
345345
}
@@ -350,15 +350,15 @@ func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %out
350350
// CHECK: %{{.+}} = linalg.pooling_nhwc_min
351351
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
352352
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
353-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
353+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
354354
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
355-
func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
355+
func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
356356
%fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
357357
%init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
358358
%cst = constant 0.000000e+00 : f32
359359
%fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
360360
%res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
361-
ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
361+
ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
362362
outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
363363
return %res : tensor<1x2x2x1xf32>
364364
}
@@ -369,11 +369,11 @@ func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
369369
// CHECK: linalg.pooling_nhwc_min
370370
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
371371
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
372-
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
372+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
373373
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
374-
func @pooling_nhwc_min(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
374+
func @pooling_nhwc_min(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
375375
linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
376-
ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
376+
ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
377377
outs(%output: memref<1x2x2x1xf32>)
378378
return
379379
}

mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
168168

169169
#map0 = affine_map<(d0, d1, d2) -> (d0)>
170170
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
171-
#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
172-
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
173-
func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
171+
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
172+
#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
173+
func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
174174
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
175175
%1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
176176
%2 = linalg.generic
@@ -183,9 +183,9 @@ func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
183183
return %2 : tensor<5x7x3xf32>
184184
}
185185

186-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
187-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
188-
// CHECK: func @generic_op_120_permultation_reshape_producer_fusion
186+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
187+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
188+
// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
189189
// CHECK-NOT: linalg.tensor_reshape
190190
// CHECK: linalg.generic
191191
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]

mlir/test/Dialect/Linalg/sparse_nd.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
// CHECK-LABEL: func @mul(
2323
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20x30x40x50x60x70x80xf32>,
24-
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<10x20x30x40x50x60x70x80xf32>,
24+
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<80x70x60x50x40x30x20x10xf32>,
2525
// CHECK-SAME: %[[VAL_2:.*2]]: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
2626
// CHECK: %[[VAL_3:.*]] = constant 3 : index
2727
// CHECK: %[[VAL_4:.*]] = constant 4 : index
@@ -34,11 +34,11 @@
3434
// CHECK: %[[VAL_11:.*]] = constant 0 : index
3535
// CHECK: %[[VAL_12:.*]] = constant 1 : index
3636
// CHECK: %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10x20x30x40x50x60x70x80xf32>
37-
// CHECK: %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
38-
// CHECK: %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
39-
// CHECK: %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
40-
// CHECK: %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
41-
// CHECK: %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xf32>
37+
// CHECK: %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
38+
// CHECK: %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
39+
// CHECK: %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
40+
// CHECK: %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
41+
// CHECK: %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xf32>
4242
// CHECK: %[[VAL_19:.*]] = memref.buffer_cast %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
4343
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<10x20x30x40x50x60x70x80xf32>
4444
// CHECK: linalg.copy(%[[VAL_19]], %[[VAL_20]]) : memref<10x20x30x40x50x60x70x80xf32>, memref<10x20x30x40x50x60x70x80xf32>
@@ -84,12 +84,12 @@
8484
// CHECK: return %[[VAL_50]] : tensor<10x20x30x40x50x60x70x80xf32>
8585
// CHECK: }
8686
func @mul(%arga: tensor<10x20x30x40x50x60x70x80xf32>,
87-
%argb: tensor<10x20x30x40x50x60x70x80xf32>,
87+
%argb: tensor<80x70x60x50x40x30x20x10xf32>,
8888
%argx: tensor<10x20x30x40x50x60x70x80xf32>)
8989
-> tensor<10x20x30x40x50x60x70x80xf32> {
9090
%0 = linalg.generic #trait_mul
9191
ins(%arga, %argb: tensor<10x20x30x40x50x60x70x80xf32>,
92-
tensor<10x20x30x40x50x60x70x80xf32>)
92+
tensor<80x70x60x50x40x30x20x10xf32>)
9393
outs(%argx: tensor<10x20x30x40x50x60x70x80xf32>) {
9494
^bb(%a: f32, %b: f32, %x: f32):
9595
%0 = mulf %a, %b : f32

0 commit comments

Comments (0)