-
Notifications
You must be signed in to change notification settings - Fork 16
Closed
Labels
bug — Something isn't working
Description
Failed case (IR after deep tile matmul):
#map = affine_map<(d0) -> (d0 * 32)>
module {
func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x1024xbf16>, %arg2: tensor<1024xbf16>) -> tensor<128x1024xbf16> attributes {llvm.emit_c_interface} {
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x1024xbf16>
%0 = tensor.empty() : tensor<128x1024xbf16>
%1 = tensor.empty() : tensor<32x16x32x32xbf16>
%pack = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<512x1024xbf16> -> tensor<32x16x32x32xbf16>
%2 = tensor.empty() : tensor<32x16x16x32x2xbf16>
%pack_0 = tensor.pack %pack inner_dims_pos = [2] inner_tiles = [2] into %2 : tensor<32x16x32x32xbf16> -> tensor<32x16x16x32x2xbf16>
%3 = scf.forall (%arg3) = (0) to (128) step (64) shared_outs(%arg4 = %0) -> (tensor<128x1024xbf16>) {
%extracted_slice = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16>
%extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 1024] [1, 1] : tensor<128x1024xbf16> to tensor<64x1024xbf16>
%9 = scf.forall (%arg5) = (0) to (32) step (2) shared_outs(%arg6 = %extracted_slice_1) -> (tensor<64x1024xbf16>) {
%10 = affine.apply #map(%arg5)
%extracted_slice_2 = tensor.extract_slice %pack_0[%arg5, 0, 0, 0, 0] [2, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<32x16x16x32x2xbf16> to tensor<2x16x16x32x2xbf16>
%extracted_slice_3 = tensor.extract_slice %arg6[0, %10] [64, 64] [1, 1] : tensor<64x1024xbf16> to tensor<64x64xbf16>
%11 = scf.for %arg7 = %c0 to %c2 step %c1 iter_args(%arg8 = %extracted_slice_3) -> (tensor<64x64xbf16>) {
%13 = affine.apply #map(%arg7)
%extracted_slice_4 = tensor.extract_slice %extracted_slice_2[%arg7, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<2x16x16x32x2xbf16> to tensor<1x16x16x32x2xbf16>
%extracted_slice_5 = tensor.extract_slice %arg8[0, %13] [64, 32] [1, 1] : tensor<64x64xbf16> to tensor<64x32xbf16>
%14 = tensor.empty() : tensor<64x32xf32>
%15 = linalg.copy ins(%extracted_slice_5 : tensor<64x32xbf16>) outs(%14 : tensor<64x32xf32>) -> tensor<64x32xf32>
%16:2 = scf.for %arg9 = %c0 to %c64 step %c32 iter_args(%arg10 = %15, %arg11 = %extracted_slice_5) -> (tensor<64x32xf32>, tensor<64x32xbf16>) {
%extracted_slice_6 = tensor.extract_slice %extracted_slice[%arg9, 0] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16>
%collapsed = tensor.collapse_shape %extracted_slice_4 [[0, 1], [2], [3], [4]] : tensor<1x16x16x32x2xbf16> into tensor<16x16x32x2xbf16>
%extracted_slice_7 = tensor.extract_slice %arg10[%arg9, 0] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
%extracted_slice_8 = tensor.extract_slice %arg11[%arg9, 0] [32, 32] [1, 1] : tensor<64x32xbf16> to tensor<32x32xbf16>
%expanded = tensor.expand_shape %extracted_slice_6 [[0], [1, 2]] output_shape [32, 16, 32] : tensor<32x512xbf16> into tensor<32x16x32xbf16>
%18 = tensor.empty() : tensor<16x32x32xbf16>
%transposed = linalg.transpose ins(%expanded : tensor<32x16x32xbf16>) outs(%18 : tensor<16x32x32xbf16>) permutation = [1, 0, 2]
%19 = linalgx.batch_reduce_matmul_vnni ins(%transposed, %collapsed : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
%20 = linalg.copy ins(%19 : tensor<32x32xf32>) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
%inserted_slice_9 = tensor.insert_slice %19 into %arg10[%arg9, 0] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
%inserted_slice_10 = tensor.insert_slice %20 into %arg11[%arg9, 0] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x32xbf16>
scf.yield %inserted_slice_9, %inserted_slice_10 : tensor<64x32xf32>, tensor<64x32xbf16>
}
%17 = affine.apply #map(%arg7)
%inserted_slice = tensor.insert_slice %16#1 into %arg8[0, %17] [64, 32] [1, 1] : tensor<64x32xbf16> into tensor<64x64xbf16>
scf.yield %inserted_slice : tensor<64x64xbf16>
}
%12 = affine.apply #map(%arg5)
scf.forall.in_parallel {
tensor.parallel_insert_slice %11 into %arg6[0, %12] [64, 64] [1, 1] : tensor<64x64xbf16> into tensor<64x1024xbf16>
}
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %9 into %arg4[%arg3, 0] [64, 1024] [1, 1] : tensor<64x1024xbf16> into tensor<128x1024xbf16>
}
}
%4 = tensor.empty() : tensor<128x1024xbf16>
%broadcasted = linalg.broadcast ins(%arg2 : tensor<1024xbf16>) outs(%4 : tensor<128x1024xbf16>) dimensions = [0]
%5 = tensor.empty() : tensor<128x1024xbf16>
%6 = linalg.add ins(%3, %broadcasted : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%5 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
%7 = tensor.empty() : tensor<128x1024xbf16>
%8 = linalg.max ins(%6, %cst : tensor<128x1024xbf16>, tensor<128x1024xbf16>) outs(%7 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
return %8 : tensor<128x1024xbf16>
}
}
Command to reproduce:
./bin/gc-opt --iterative-tiling-and-fusion debug.mlir --debug
Metadata
Assignees
Labels
bug — Something isn't working