Skip to content

Unexpected Behavior on Affine-Loop-Fusion #95230

@sgjzfzzf

Description

@sgjzfzzf

Hi, I'm developing a compiler for ONNX based on MLIR. I'm trying to optimize the generated code with the Affine passes, but I've run into a problem: an error occurs with the LayerNormalization operator. Here is the code generated automatically by my compiler.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %cst = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %c0_1 = arith.constant 0 : index
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant 1.000000e+00 : f32
    %view_4 = memref.view %view[%c0_1][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst : memref<768xf32>
    %1 = bufferization.to_memref %cst_0 : memref<768xf32>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.addf %in, %out : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %view_4 : memref<1x128x768xf32>, memref<1x128x1xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %out: f32):
      %2 = arith.subf %in, %in_5 : f32
      linalg.yield %2 : f32
    }
    linalg.fill ins(%cst_2 : f32) outs(%view_4 : memref<1x128x1xf32>)
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.mulf %in, %in : f32
      %3 = arith.addf %out, %2 : f32
      linalg.yield %3 : f32
    }
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %view_4, %0, %1 : memref<1x128x768xf32>, memref<1x128x1xf32>, memref<768xf32>, memref<768xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
      %cst_8 = arith.constant 9.99999996E-13 : f32
      %2 = arith.addf %in_5, %cst_8 : f32
      %3 = math.sqrt %2 : f32
      %4 = arith.divf %cst_3, %3 : f32
      %5 = arith.mulf %in, %4 : f32
      %6 = arith.mulf %5, %in_6 : f32
      %7 = arith.addf %6, %in_7 : f32
      linalg.yield %7 : f32
    }
    return
  }
}

Then, I use mlir-opt-18 -convert-linalg-to-affine-loops <filename> to lower it to the Affine dialect.

module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_0 = arith.constant 7.680000e+02 : f32
    %cst_1 = arith.constant 1.000000e+00 : f32
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_4 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %view_5 = memref.view %view[%c0][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_4 : memref<768xf32>
    %1 = bufferization.to_memref %cst_3 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.addf %2, %3 : f32
          affine.store %4, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.subf %2, %3 : f32
          affine.store %4, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          affine.store %cst_2, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.mulf %2, %2 : f32
          %5 = arith.addf %3, %4 : f32
          affine.store %5, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
          %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
          affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}

After that, I decided to use the Affine-Loop-Fusion pass to optimize it with mlir-opt-18 -convert-linalg-to-affine-loops -affine-loop-fusion=mode=greedy <filename>.

module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %c0_1 = arith.constant 0 : index
    %c0_2 = arith.constant 0 : index
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c0_5 = arith.constant 0 : index
    %c0_6 = arith.constant 0 : index
    %c0_7 = arith.constant 0 : index
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_8 = arith.constant 7.680000e+02 : f32
    %cst_9 = arith.constant 1.000000e+00 : f32
    %cst_10 = arith.constant 0.000000e+00 : f32
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_12 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0_13 = arith.constant 0 : index
    %view = memref.view %arg2[%c0_13][] : memref<512xi8> to memref<512xi8>
    %view_14 = memref.view %view[%c0_13][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_12 : memref<768xf32>
    %1 = bufferization.to_memref %cst_11 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.addf %4, %5 : f32
          affine.store %6, %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        %2 = affine.load %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        %3 = arith.divf %2, %cst_8 : f32
        affine.store %3, %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_2, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.subf %4, %5 : f32
          affine.store %6, %arg0[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
        }
        affine.store %cst_10, %view_14[%c0_4, %arg4, %c0_3] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg0[%c0_5, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.mulf %4, %4 : f32
          %7 = arith.addf %5, %6 : f32
          affine.store %7, %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}

Please take a look at the last affine-for loop. The comparison follows:

    // ...
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
          %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
          affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }


    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        // ...
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }

The pass incorrectly fuses the divf instruction into the inner loop. Within each iteration of the 0-to-128 loop, the divf should be executed only once, but due to the wrong fusion it is executed 768 times instead. I also verified this on a real example: the output changes after the pass is applied, as can be seen in the code above.

Could you please provide me with some information on this issue? Is it a bug, or is something wrong with my code or my choice of optimizations? Thank you so much for reading!

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions