Description
Hi, I'm developing a compiler for ONNX based on MLIR. I'm trying to optimize the generated code with the Affine passes, but I need some help: an error shows up in the LayerNormalization operator. The code my compiler generates automatically follows, after a short sketch of the math it implements.
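For reference, this is my reading of the standard LayerNormalization semantics (the symbols here are used only in this sketch):

$$y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta, \qquad \mu = \frac{1}{768} \sum_i x_i, \qquad \sigma^2 = \frac{1}{768} \sum_i (x_i - \mu)^2$$

In this test case the reduction runs over the last axis (size 768), $\epsilon \approx 10^{-12}$ (the 9.99999996E-13 constant), and $\gamma = 1$, $\beta = 0$ (the dense<1.0> and dense<0.0> tensors).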
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %cst = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %c0_1 = arith.constant 0 : index
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant 1.000000e+00 : f32
    %view_4 = memref.view %view[%c0_1][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst : memref<768xf32>
    %1 = bufferization.to_memref %cst_0 : memref<768xf32>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.addf %in, %out : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %view_4 : memref<1x128x768xf32>, memref<1x128x1xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %out: f32):
      %2 = arith.subf %in, %in_5 : f32
      linalg.yield %2 : f32
    }
    linalg.fill ins(%cst_2 : f32) outs(%view_4 : memref<1x128x1xf32>)
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.mulf %in, %in : f32
      %3 = arith.addf %out, %2 : f32
      linalg.yield %3 : f32
    }
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %view_4, %0, %1 : memref<1x128x768xf32>, memref<1x128x1xf32>, memref<768xf32>, memref<768xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
      %cst_8 = arith.constant 9.99999996E-13 : f32
      %2 = arith.addf %in_5, %cst_8 : f32
      %3 = math.sqrt %2 : f32
      %4 = arith.divf %cst_3, %3 : f32
      %5 = arith.mulf %in, %4 : f32
      %6 = arith.mulf %5, %in_6 : f32
      %7 = arith.addf %6, %in_7 : f32
      linalg.yield %7 : f32
    }
    return
  }
}

Then, I use mlir-opt-18 -convert-linalg-to-affine-loops <filename> to lower it to the Affine dialect:
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_0 = arith.constant 7.680000e+02 : f32
    %cst_1 = arith.constant 1.000000e+00 : f32
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_4 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %view_5 = memref.view %view[%c0][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_4 : memref<768xf32>
    %1 = bufferization.to_memref %cst_3 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.addf %2, %3 : f32
          affine.store %4, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.subf %2, %3 : f32
          affine.store %4, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          affine.store %cst_2, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.mulf %2, %2 : f32
          %5 = arith.addf %3, %4 : f32
          affine.store %5, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
          %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
          affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}

After that, I decided to use the affine-loop-fusion pass to optimize it: mlir-opt-18 -convert-linalg-to-affine-loops -affine-loop-fusion=mode=greedy <filename>. The output follows:
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %c0_1 = arith.constant 0 : index
    %c0_2 = arith.constant 0 : index
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c0_5 = arith.constant 0 : index
    %c0_6 = arith.constant 0 : index
    %c0_7 = arith.constant 0 : index
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_8 = arith.constant 7.680000e+02 : f32
    %cst_9 = arith.constant 1.000000e+00 : f32
    %cst_10 = arith.constant 0.000000e+00 : f32
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_12 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0_13 = arith.constant 0 : index
    %view = memref.view %arg2[%c0_13][] : memref<512xi8> to memref<512xi8>
    %view_14 = memref.view %view[%c0_13][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_12 : memref<768xf32>
    %1 = bufferization.to_memref %cst_11 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.addf %4, %5 : f32
          affine.store %6, %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        %2 = affine.load %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        %3 = arith.divf %2, %cst_8 : f32
        affine.store %3, %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_2, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.subf %4, %5 : f32
          affine.store %6, %arg0[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
        }
        affine.store %cst_10, %view_14[%c0_4, %arg4, %c0_3] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg0[%c0_5, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.mulf %4, %4 : f32
          %7 = arith.addf %5, %6 : f32
          affine.store %7, %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}

Please take a look at the last affine.for loop. The comparison follows (the first two loops come from the unfused output, the third is the corresponding fused loop):
// ...
affine.for %arg3 = 0 to 1 {
  affine.for %arg4 = 0 to 128 {
    affine.for %arg5 = 0 to 1 {
      %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
      %3 = arith.divf %2, %cst_0 : f32
      affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
    }
  }
}
affine.for %arg3 = 0 to 1 {
  affine.for %arg4 = 0 to 128 {
    affine.for %arg5 = 0 to 768 {
      %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
      %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
      %4 = affine.load %0[%arg5] : memref<768xf32>
      %5 = affine.load %1[%arg5] : memref<768xf32>
      %6 = arith.addf %3, %cst : f32
      %7 = math.sqrt %6 : f32
      %8 = arith.divf %cst_1, %7 : f32
      %9 = arith.mulf %2, %8 : f32
      %10 = arith.mulf %9, %4 : f32
      %11 = arith.addf %10, %5 : f32
      affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
    }
  }
}
affine.for %arg3 = 0 to 1 {
  affine.for %arg4 = 0 to 128 {
    // ...
    affine.for %arg5 = 0 to 768 {
      %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
      %5 = arith.divf %4, %cst_8 : f32
      affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
      %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
      %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
      %8 = affine.load %0[%arg5] : memref<768xf32>
      %9 = affine.load %1[%arg5] : memref<768xf32>
      %10 = arith.addf %7, %cst : f32
      %11 = math.sqrt %10 : f32
      %12 = arith.divf %cst_9, %11 : f32
      %13 = arith.mulf %6, %12 : f32
      %14 = arith.mulf %13, %8 : f32
      %15 = arith.addf %14, %9 : f32
      affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
    }
  }
}

The pass fuses the divf instruction into the wrong loop. Within each iteration of the 128-trip loop, that divf should execute only once (its source loop has a trip count of 1), but after the incorrect fusion it executes 768 times, repeatedly dividing the already-divided value in %view_14. I also verified this on a real example: the output changes after running the pass, just as the code suggests.
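To make the miscompilation concrete, here is a minimal Python sketch (my own, not produced by the compiler) of what the two schedules compute for a single iteration of the 128-trip loop; acc stands for the value the sum-of-squares reduction left in %view_14, and N = 768 plays the role of %cst_0 / %cst_8:

# Sketch of the divide-by-768 step before and after the bad fusion.
N = 768
acc = 2.0 * N  # pretend the sum-of-squares reduction produced this value

# Unfused schedule: the divide sits in a 1-trip loop, so it runs once.
var_unfused = acc / N  # executed exactly once -> 2.0

# Fused schedule: the read-divide-store sequence was pulled into the
# 768-trip loop, so each iteration divides the previous result again.
var_fused = acc
for _ in range(N):
    var_fused /= N  # after 768 iterations: acc / N**768, underflows to 0.0

print(var_unfused)  # 2.0
print(var_fused)    # 0.0

In the real fused loop the effect is even worse, because the normalization body in the same loop reads %view_14 after each partial division, so every element of the row sees a different, progressively smaller "variance".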
Could you please provide me with some information on this issue? Is it a bug, or is something wrong with my code or my use of the passes? Thank you so much for reading!