@@ -480,3 +480,98 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
480480// CHECK: scf.reduce
481481// CHECK: }
482482// CHECK: memref.dealloc [[SUM]]
483+
484+ // -----
485+
// Positive fusion test: two scf.parallel loops over the same 2x2 iteration
// space communicate through the temporary buffer %sum. Both loops compute the
// second %sum index with the same affine map (d0 + d1) applied to their own
// induction variables (%i, %j), so the store in the first loop and the load in
// the second loop use matching indices. The FileCheck directives that follow
// this function expect both bodies to end up inside a single scf.parallel.
486+ func.func @fuse_same_indices_by_affine_apply (
487+ %A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
488+ %c0 = arith.constant 0 : index
489+ %c1 = arith.constant 1 : index
490+ %c2 = arith.constant 2 : index
// %sum is 2x3 rather than 2x2: with %i and %j each ranging over 0..1, the
// computed column index %i + %j ranges over 0..2.
491+ %sum = memref.alloc () : memref <2 x3 xf32 >
// First loop: copy %B[%i, %j] into %sum[%i, %i + %j].
492+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
493+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
494+ %1 = affine.apply affine_map <(d0 , d1 ) -> (d0 + d1 )>(%i , %j )
495+ memref.store %B_elem , %sum [%i , %1 ] : memref <2 x3 xf32 >
496+ scf.reduce
497+ }
// Second loop: read %sum back at the same [%i, %i + %j] position, multiply by
// %A[%i, %j], and write the product to %B[%i, %j].
498+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
499+ %1 = affine.apply affine_map <(d0 , d1 ) -> (d0 + d1 )>(%i , %j )
500+ %sum_elem = memref.load %sum [%i , %1 ] : memref <2 x3 xf32 >
501+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
502+ %product = arith.mulf %sum_elem , %A_elem : f32
503+ memref.store %product , %B [%i , %j ] : memref <2 x2 xf32 >
504+ scf.reduce
505+ }
506+ memref.dealloc %sum : memref <2 x3 xf32 >
507+ return
508+ }
509+ // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
510+ // CHECK-LABEL: fuse_same_indices_by_affine_apply
511+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
512+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
513+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
514+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
515+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
516+ // CHECK-NEXT: scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
517+ // CHECK-NEXT: %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
518+ // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
519+ // CHECK-NEXT: memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
520+ // CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
521+ // CHECK-NEXT: %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
522+ // CHECK-NEXT: %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
523+ // CHECK-NEXT: %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
524+ // CHECK-NEXT: memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
525+ // CHECK-NEXT: scf.reduce
526+ // CHECK-NEXT: }
527+ // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
528+ // CHECK-NEXT: return
529+
530+ // -----
531+
// Negative fusion test: same loop structure as a store-then-load pair through
// the temporary buffer %sum, but the second operand of each affine.apply is a
// function argument (%OffsetA in the first loop, %OffsetB in the second)
// instead of the induction variable %j. The FileCheck directives that follow
// this function expect the two scf.parallel loops to remain separate.
// NOTE(review): presumably fusion is rejected because the store and load
// indices cannot be shown to be equal when they depend on two different
// non-induction values; confirm against the pass implementation.
532+ func.func @do_not_fuse_affine_apply_to_non_ind_var (
533+ %A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >, %OffsetA: index , %OffsetB: index ) {
534+ %c0 = arith.constant 0 : index
535+ %c1 = arith.constant 1 : index
536+ %c2 = arith.constant 2 : index
537+ %sum = memref.alloc () : memref <2 x3 xf32 >
// First loop: store %B[%i, %j] into %sum[%i, %i + %OffsetA].
538+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
539+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
540+ %1 = affine.apply affine_map <(d0 , d1 ) -> (d0 + d1 )>(%i , %OffsetA )
541+ memref.store %B_elem , %sum [%i , %1 ] : memref <2 x3 xf32 >
542+ scf.reduce
543+ }
// Second loop: load from %sum[%i, %i + %OffsetB] (a different offset operand
// than the store above), multiply by %A[%i, %j], and write to %B[%i, %j].
544+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
545+ %1 = affine.apply affine_map <(d0 , d1 ) -> (d0 + d1 )>(%i , %OffsetB )
546+ %sum_elem = memref.load %sum [%i , %1 ] : memref <2 x3 xf32 >
547+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
548+ %product = arith.mulf %sum_elem , %A_elem : f32
549+ memref.store %product , %B [%i , %j ] : memref <2 x2 xf32 >
550+ scf.reduce
551+ }
552+ memref.dealloc %sum : memref <2 x3 xf32 >
553+ return
554+ }
555+ // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
556+ // CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
557+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
558+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
559+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
560+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
561+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
562+ // CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
563+ // CHECK-NEXT: %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
564+ // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
565+ // CHECK-NEXT: memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
566+ // CHECK-NEXT: scf.reduce
567+ // CHECK-NEXT: }
568+ // CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
569+ // CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
570+ // CHECK-NEXT: %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
571+ // CHECK-NEXT: %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
572+ // CHECK-NEXT: %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
573+ // CHECK-NEXT: memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
574+ // CHECK-NEXT: scf.reduce
575+ // CHECK-NEXT: }
576+ // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
577+ // CHECK-NEXT: return
0 commit comments