Skip to content

IterativeTilingAndFusion fails on the latest config-enhanced deep-tile matmul #315

@yifeizh2

Description

@yifeizh2

Failed case (IR after deep tile matmul):

#map = affine_map<(d0) -> (d0 * 32)>
module {
  func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x1024xbf16>, %arg2: tensor<1024xbf16>) -> tensor<128x1024xbf16> attributes {llvm.emit_c_interface} {
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<128x1024xbf16>
    %0 = tensor.empty() : tensor<128x1024xbf16>
    %1 = tensor.empty() : tensor<32x16x32x32xbf16>
    %pack = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<512x1024xbf16> -> tensor<32x16x32x32xbf16>
    %2 = tensor.empty() : tensor<32x16x16x32x2xbf16>
    %pack_0 = tensor.pack %pack inner_dims_pos = [2] inner_tiles = [2] into %2 : tensor<32x16x32x32xbf16> -> tensor<32x16x16x32x2xbf16>
    %3 = scf.forall (%arg3) = (0) to (128) step (64) shared_outs(%arg4 = %0) -> (tensor<128x1024xbf16>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 1024] [1, 1] : tensor<128x1024xbf16> to tensor<64x1024xbf16>
      %9 = scf.forall (%arg5) = (0) to (32) step (2) shared_outs(%arg6 = %extracted_slice_1) -> (tensor<64x1024xbf16>) {
        %10 = affine.apply #map(%arg5)
        %extracted_slice_2 = tensor.extract_slice %pack_0[%arg5, 0, 0, 0, 0] [2, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<32x16x16x32x2xbf16> to tensor<2x16x16x32x2xbf16>
        %extracted_slice_3 = tensor.extract_slice %arg6[0, %10] [64, 64] [1, 1] : tensor<64x1024xbf16> to tensor<64x64xbf16>
        %11 = scf.for %arg7 = %c0 to %c2 step %c1 iter_args(%arg8 = %extracted_slice_3) -> (tensor<64x64xbf16>) {
          %13 = affine.apply #map(%arg7)
          %extracted_slice_4 = tensor.extract_slice %extracted_slice_2[%arg7, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<2x16x16x32x2xbf16> to tensor<1x16x16x32x2xbf16>
          %extracted_slice_5 = tensor.extract_slice %arg8[0, %13] [64, 32] [1, 1] : tensor<64x64xbf16> to tensor<64x32xbf16>
          %14 = tensor.empty() : tensor<64x32xf32>
          %15 = linalg.copy ins(%extracted_slice_5 : tensor<64x32xbf16>) outs(%14 : tensor<64x32xf32>) -> tensor<64x32xf32>
          %16:2 = scf.for %arg9 = %c0 to %c64 step %c32 iter_args(%arg10 = %15, %arg11 = %extracted_slice_5) -> (tensor<64x32xf32>, tensor<64x32xbf16>) {
            %extracted_slice_6 = tensor.extract_slice %extracted_slice[%arg9, 0] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16>
            %collapsed = tensor.collapse_shape %extracted_slice_4 [[0, 1], [2], [3], [4]] : tensor<1x16x16x32x2xbf16> into tensor<16x16x32x2xbf16>
            %extracted_slice_7 = tensor.extract_slice %arg10[%arg9, 0] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
            %extracted_slice_8 = tensor.extract_slice %arg11[%arg9, 0] [32, 32] [1, 1] : tensor<64x32xbf16> to tensor<32x32xbf16>
            %expanded = tensor.expand_shape %extracted_slice_6 [[0], [1, 2]] output_shape [32, 16, 32] : tensor<32x512xbf16> into tensor<32x16x32xbf16>
            %18 = tensor.empty() : tensor<16x32x32xbf16>
            %transposed = linalg.transpose ins(%expanded : tensor<32x16x32xbf16>) outs(%18 : tensor<16x32x32xbf16>) permutation = [1, 0, 2] 
            %19 = linalgx.batch_reduce_matmul_vnni ins(%transposed, %collapsed : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
            %20 = linalg.copy ins(%19 : tensor<32x32xf32>) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
            %inserted_slice_9 = tensor.insert_slice %19 into %arg10[%arg9, 0] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
            %inserted_slice_10 = tensor.insert_slice %20 into %arg11[%arg9, 0] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x32xbf16>
            scf.yield %inserted_slice_9, %inserted_slice_10 : tensor<64x32xf32>, tensor<64x32xbf16>
          }
          %17 = affine.apply #map(%arg7)
          %inserted_slice = tensor.insert_slice %16#1 into %arg8[0, %17] [64, 32] [1, 1] : tensor<64x32xbf16> into tensor<64x64xbf16>
          scf.yield %inserted_slice : tensor<64x64xbf16>
        }
        %12 = affine.apply #map(%arg5)
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %11 into %arg6[0, %12] [64, 64] [1, 1] : tensor<64x64xbf16> into tensor<64x1024xbf16>
        }
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %9 into %arg4[%arg3, 0] [64, 1024] [1, 1] : tensor<64x1024xbf16> into tensor<128x1024xbf16>
      }
    }
    %4 = tensor.empty() : tensor<128x1024xbf16>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<1024xbf16>) outs(%4 : tensor<128x1024xbf16>) dimensions = [0] 
    %5 = tensor.empty() : tensor<128x1024xbf16>
    %6 = linalg.add ins(%3, %broadcasted : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%5 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
    %7 = tensor.empty() : tensor<128x1024xbf16>
    %8 = linalg.max ins(%6, %cst : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%7 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
    return %8 : tensor<128x1024xbf16>
  }
}

Command to reproduce:

./bin/gc-opt --iterative-tiling-and-fusion debug.mlir --debug

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions