@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
 func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
   %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
 func.func @tensor.expand_shape_of_slice(
     %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
-  // CHECK: %[[C7:.*]] = arith.constant 7 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<?x7x2x5xf32>
 }
-
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
 }
+// -----
 
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+      : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?x?x256xf32>
+}
 // -----
 
 // CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
   // CHECK: }
   return
 }
+
+// -----
+