@@ -344,3 +344,117 @@ util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040
344344// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
345345// CHECK-SAME: tensor<1x4x5760x1xbf16> into tensor<4x5760x1xbf16>
346346// CHECK: util.return %[[COLLAPSE]] : tensor<4x5760x1xbf16>
347+
348+ // -----
349+
// Test folding a leading unit dim from tensor.extract: the source is expected
// to be collapsed to tensor<4x8xf32> and read with the two remaining indices.
util.func @fold_unit_dims_from_extract_leading(%arg0: tensor<1x4x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x4x8xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_leading
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x4x8xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}}
// CHECK-SAME: tensor<1x4x8xf32> into tensor<4x8xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX1]], %[[IDX2]]]
// CHECK: util.return %[[EXTRACT]] : f32
363+
364+ // -----
365+
// Test folding a trailing unit dim from tensor.extract: the source is expected
// to be collapsed to tensor<4x8xf32> and read with the first two indices.
util.func @fold_unit_dims_from_extract_trailing(%arg0: tensor<4x8x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<4x8x1xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_trailing
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x8x1xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]{{\]}}
// CHECK-SAME: tensor<4x8x1xf32> into tensor<4x8xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX0]], %[[IDX1]]]
// CHECK: util.return %[[EXTRACT]] : f32
379+
380+ // -----
381+
// Test folding an interior unit dim from tensor.extract: the source is expected
// to be collapsed to tensor<4x8xf32> and read with the outer two indices.
util.func @fold_unit_dims_from_extract_middle(%arg0: tensor<4x1x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<4x1x8xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_middle
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1x8xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]{{\]}}
// CHECK-SAME: tensor<4x1x8xf32> into tensor<4x8xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX0]], %[[IDX2]]]
// CHECK: util.return %[[EXTRACT]] : f32
395+
396+ // -----
397+
// Test folding several non-adjacent unit dims from tensor.extract: the source
// is expected to be collapsed to tensor<4x8xf32>, and the extract must use only
// the indices of the two non-unit dims (%idx1 and %idx3).
util.func @fold_unit_dims_from_extract_multiple(%arg0: tensor<1x4x1x8x1xf32>, %idx0: index, %idx1: index, %idx2: index, %idx3: index, %idx4: index) -> f32 {
  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2, %idx3, %idx4] : tensor<1x4x1x8x1xf32>
  util.return %extracted : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_multiple
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x4x1x8x1xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX4:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3, 4]{{\]}}
// CHECK-SAME: tensor<1x4x1x8x1xf32> into tensor<4x8xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX1]], %[[IDX3]]]
// CHECK: util.return %[[EXTRACT]] : f32
413+
414+ // -----
415+
// Test folding a run of consecutive unit dims from tensor.extract: the source
// is expected to be collapsed to tensor<8xf32> and read with the last index.
util.func @fold_unit_dims_from_extract_consecutive(%arg0: tensor<1x1x1x8xf32>, %idx0: index, %idx1: index, %idx2: index, %idx3: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2, %idx3] : tensor<1x1x1x8xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_consecutive
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x1x8xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX3:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]{{\]}}
// CHECK-SAME: tensor<1x1x1x8xf32> into tensor<8xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX3]]]
// CHECK: util.return %[[EXTRACT]] : f32
431+
432+ // -----
433+
// Test folding unit dims that surround a dynamic dim in tensor.extract: the
// source is expected to be collapsed to tensor<?xf32> and read with the index
// of the dynamic dim.
util.func @fold_unit_dims_from_extract_dynamic(%arg0: tensor<1x?x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x?x1xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_dynamic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x1xf32>
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}}
// CHECK-SAME: tensor<1x?x1xf32> into tensor<?xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX1]]]
// CHECK: util.return %[[EXTRACT]] : f32
447+
448+ // -----
449+
// Test folding from tensor.extract when every dim is a unit dim: the source is
// expected to be collapsed to the rank-0 tensor<f32> before the extract.
util.func @fold_unit_dims_from_extract_all_unit(%arg0: tensor<1x1x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
  %elem = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x1x1xf32>
  util.return %elem : f32
}
// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_all_unit
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x1xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]]
// CHECK-SAME: tensor<f32>
// CHECK: util.return %[[EXTRACT]] : f32
0 commit comments