34 changes: 0 additions & 34 deletions .github/workflows/tpp-benchmark.yml
@@ -11,10 +11,6 @@ on:
description: "Run on Zen5"
type: boolean
default: true
RUN_CLX_BENCH:
description: "Run on CLX"
type: boolean
default: true
RUN_ARL_BENCH:
description: "Run on ARL"
type: boolean
@@ -100,36 +96,6 @@ jobs:
${{ github.workspace }}/scripts/github/benchmark.sh -o"
${{ env.SRUN }} --partition=zen5 --time=0:30:00 -- $CMD

TPP-MLIR-CLX-BASE:
runs-on: pcl-tiergarten
if: |
(github.event_name == 'push') ||
(github.event_name == 'workflow_dispatch' && inputs.RUN_CLX_BENCH) ||
(github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'benchmark-full'))
needs: Check_LLVM
steps:
- uses: actions/checkout@v4
- name: CLX Base
run: |-
CMD="KIND=Release COMPILER=clang LINKER=lld BENCHMARK_NUM_ITER=${{ env.NUM_ITER }} \
${{ github.workspace }}/scripts/github/benchmark.sh -b -p"
${{ env.SRUN }} --partition=clxap --time=0:30:00 -- $CMD

TPP-MLIR-CLX-OMP:
runs-on: pcl-tiergarten
if: |
(github.event_name == 'push') ||
(github.event_name == 'workflow_dispatch' && inputs.RUN_CLX_BENCH) ||
(github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'benchmark-full'))
needs: Check_LLVM
steps:
- uses: actions/checkout@v4
- name: CLX OpenMP
run: |-
CMD="KIND=Release COMPILER=clang LINKER=lld BENCHMARK_NUM_ITER=${{ env.NUM_ITER }} \
${{ github.workspace }}/scripts/github/benchmark.sh -o"
${{ env.SRUN }} --partition=clxap --time=0:30:00 -- $CMD

TPP-MLIR-ARL-BASE:
runs-on: pcl-tiergarten
if: |
2 changes: 1 addition & 1 deletion .github/workflows/tpp-llvm.yml
@@ -27,6 +27,6 @@ jobs:
- name: LLVM CUDA
run: |-
GPU=cuda scripts/github/check_llvm.sh || \
${{ env.SRUN }} --partition=a100,v100 --time=0:30:00 -- \
${{ env.SRUN }} --partition=a100 --time=0:30:00 -- \
'KIND=RelWithDebInfo COMPILER=clang GPU=cuda \
${{ github.workspace }}/scripts/github/build_llvm.sh'
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
@@ -1 +1 @@
eb6da944af31dd684be3ab2f93f453a3837a72c6
8eba28bc8ce9447d09edda6fc79e2191a1669252
4 changes: 2 additions & 2 deletions lib/TPP/CMakeLists.txt
@@ -23,8 +23,8 @@ add_mlir_library(TPPPipeline

LINK_LIBS PUBLIC
MLIRIR
${mlir_dialect_libs}
${conversion_libs}
MLIRRegisterAllDialects
MLIRRegisterAllPasses
TPPGPU
TPPPassBundles
)
14 changes: 11 additions & 3 deletions lib/TPP/DefaultPipeline.cpp
@@ -6,24 +6,32 @@
//
//===----------------------------------------------------------------------===//

#include "TPP/PassBundles.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/CommandLine.h"

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/Passes.h"

#include "TPP/Dialect/Check/BufferizableOpInterfaceImpl.h"
#include "TPP/Dialect/Check/CheckDialect.h"
#include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h"
#include "TPP/Dialect/Perf/PerfDialect.h"
#include "TPP/Dialect/Perf/PerfOps.h"
#include "TPP/Dialect/Xsmm/XsmmDialect.h"
#include "TPP/PassBundles.h"
#include "TPP/PassUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Transforms/Passes.h"

#include <string>

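Note: the MLIRRegisterAllDialects and MLIRRegisterAllPasses libraries added in lib/TPP/CMakeLists.txt above back the registerAllDialects() / registerAllPasses() entry points declared by the InitAllDialects.h / InitAllPasses.h headers this file includes. A minimal, hedged sketch of how a pipeline driver typically wires them up (the setupContext helper is illustrative and not part of this PR; that these entry points are what the new link libraries provide is an assumption about the bumped LLVM build layout):

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"

// Register every upstream dialect and pass so the default pipeline can
// parse arbitrary input IR and assemble its pass pipeline.
static void setupContext(mlir::MLIRContext &ctx) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);  // assumed to live in MLIRRegisterAllDialects
  ctx.appendDialectRegistry(registry);
  mlir::registerAllPasses();            // assumed to live in MLIRRegisterAllPasses
}
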
1 change: 1 addition & 0 deletions lib/TPP/GPU/GpuPipeline.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
22 changes: 12 additions & 10 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -459,14 +459,15 @@ namespace {
static SmallVector<int64_t>
getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
assert(linalgOp && "expect a valid linalgOp");
if (isa<linalg::Conv2DNchwFchwOp>(linalgOp) ||
isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
auto *op = linalgOp.getOperation();
if (isa<linalg::Conv2DNchwFchwOp>(op) ||
isa<linalg::Conv2DNhwcHwcfOp>(op)) {
return {32, 32};
}
assert(isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchMatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeAOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp));
assert(isa<linalg::MatmulOp>(op) ||
isa<linalg::BatchMatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::MatmulTransposeBOp>(op));
return {32, 32, 32};
}

@@ -492,12 +493,13 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
auto packControlFn = [&](linalg::LinalgOp linalgOp)
-> std::optional<linalg::BlockPackMatmulOptions> {
linalg::BlockPackMatmulOptions options;
auto *op = linalgOp.getOperation();

// Pack only these named matmul variants.
if (!(isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeAOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp) ||
isa<linalg::BatchMatmulOp>(linalgOp))) {
if (!(isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::MatmulTransposeBOp>(op) ||
isa<linalg::BatchMatmulOp>(op))) {
return std::nullopt;
}

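Note: the hunks above replace isa<> checks on the linalg::LinalgOp interface handle with checks on the underlying Operation*, obtained via getOperation(). A minimal, hedged sketch of that pattern (the isPackableMatmul helper is illustrative only, not code from this PR):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Casting.h"

// Classify a LinalgOp by retrieving its underlying Operation* first,
// then dispatching on the concrete named-op types.
static bool isPackableMatmul(mlir::linalg::LinalgOp linalgOp) {
  mlir::Operation *op = linalgOp.getOperation();
  return llvm::isa<mlir::linalg::MatmulOp, mlir::linalg::BatchMatmulOp>(op);
}

Dispatching on the Operation* keeps the checks uniform whether the caller holds the interface or the concrete op.
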
5 changes: 2 additions & 3 deletions scripts/ci/setup_gpu_env.sh
@@ -10,8 +10,7 @@ source ${SCRIPT_DIR}/ci/common.sh
# Env CUDA setup
if [[ ${GPU,,} =~ "cuda" ]]; then
echo "Setting up CUDA environment"
echo "Hard-coding CUDA-compatible GCC version (12.3)"
source /swtools/gcc/gcc-12.3.0/gcc_vars.sh
source /swtools/cuda/latest/cuda_vars.sh
echo "Hard-coding MLIR-compatible CUDA version (12.9)"
source /swtools/cuda/12.9.0/cuda_vars.sh
check_program nvcc
fi
4 changes: 0 additions & 4 deletions scripts/github/build_tpp.sh
@@ -32,10 +32,6 @@ echo "--- ENVIRONMENT"
if [ ! "${COMPILER}" ]; then
COMPILER=clang
fi
if [ "${COMPILER}" == "gcc" ]; then
echo "Hard-coding GCC to a known stable version (12.3)"
source /swtools/gcc/gcc-12.3.0/gcc_vars.sh
fi
if [ "${SANITIZERS}" ]; then
SANITIZERS="-S"
fi
9 changes: 4 additions & 5 deletions test/Integration/vector-contract-to-outerproduct.mlir
@@ -1,19 +1,18 @@
// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty
// RUN: fpcmp -r 0.0001 %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty

// RUN: tpp-opt %s | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty
// RUN: fpcmp -r 0.0001 %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty

// RUN: tpp-opt %s | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty
// RUN: fpcmp -r 0.0001 %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty

// RUN: tpp-opt %s | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMAB --allow-empty

// RUN: fpcmp -r 0.0001 %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMAB --allow-empty

// DIFF-NOT: {{.}}
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
46 changes: 0 additions & 46 deletions test/Passes/DefaultPipeline/linalg-matmul-variants.mlir
@@ -24,52 +24,6 @@ func.func @matmul(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>,

// -----

func.func @matmul_transpose_a(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x2048xbf16>)
-> tensor<2048x2048xbf16> {
%0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<2048x2048xbf16>, tensor<2048x2048xbf16>)
outs(%arg2: tensor<2048x2048xbf16>)
-> tensor<2048x2048xbf16>
return %0 : tensor<2048x2048xbf16>
}

// CHECK-LABEL: @matmul_transpose_a(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>
// CHECK: memref.subview %[[ARG0]]
// CHECK: linalg.transpose
// CHECK: memref.subview %[[ARG1]]
// CHECK: call @xsmm_unary_invoke
// CHECK: memref.subview %[[ARG2]]
// CHECK: call @xsmm_intel_amx_tile_config_invoke
// CHECK: call @xsmm_brgemm_invoke
// CHECK: call @xsmm_intel_amx_tile_config_invoke

// -----

func.func @matmul_transpose_b(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x2048xbf16>)
-> tensor<2048x2048xbf16> {
%0 = linalg.matmul_transpose_b ins(%arg0, %arg1: tensor<2048x2048xbf16>, tensor<2048x2048xbf16>)
outs(%arg2: tensor<2048x2048xbf16>)
-> tensor<2048x2048xbf16>
return %0 : tensor<2048x2048xbf16>
}

// CHECK-LABEL: @matmul_transpose_b(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<2048x2048xbf16>
// CHECK: memref.subview %[[ARG0]]
// CHECK: call @xsmm_unary_invoke
// CHECK: memref.subview %[[ARG1]]
// CHECK: linalg.transpose
// CHECK: memref.subview %[[ARG2]]
// CHECK: call @xsmm_intel_amx_tile_config_invoke
// CHECK: call @xsmm_brgemm_invoke
// CHECK: call @xsmm_intel_amx_tile_config_invoke

// -----

func.func @batch_matmul(%arg0: tensor<8x2048x2048xbf16>, %arg1: tensor<8x2048x2048xbf16>, %arg2: tensor<8x2048x2048xbf16>)
-> tensor<8x2048x2048xbf16> {
%0 = linalg.batch_matmul ins(%arg0, %arg1: tensor<8x2048x2048xbf16>, tensor<8x2048x2048xbf16>)
23 changes: 2 additions & 21 deletions test/Passes/fold-add-into-dest.mlir
@@ -36,40 +36,21 @@ func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {

// -----

!type = tensor<2048x2048xf32>
func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
%0 = arith.constant dense<1.111111e+00> : !type
%cst = arith.constant 0.000000e+00 : f32
%1 = tensor.empty() : !type
%2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
%3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
%4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
%5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
return %5 : !type
}

// CHECK-LABEL: func.func @expect_add_to_fold
// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins(%[[X:.+]]) outs(%[[ACC]]
// CHECK-NEXT: return %[[RES]]

// -----

!type = tensor<2048x2048xf32>
func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
%0 = arith.constant dense<1.111111e+00> : !type
%cst = arith.constant 0.000000e+00 : f32
%1 = tensor.empty() : !type
%2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
%3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
%3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
%4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
return %4 : !type
}


// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
// CHECK: linalg.fill
// CHECK-NEXT: linalg.matmul_transpose_b
// CHECK-NEXT: linalg.matmul
// CHECK-NEXT: linalg.add
// CHECK-NEXT: return

2 changes: 1 addition & 1 deletion test/Passes/pass-convert-gemm-to-parallel-tile.mlir
@@ -29,7 +29,7 @@ module {
// CHECK: %[[temp0:.*]] = call @xsmm_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]])
// CHECK: omp.parallel {
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
// CHECK: omp.loop_nest (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) collapse(2) {
// CHECK: memref.alloca_scope {
// CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index
6 changes: 3 additions & 3 deletions test/Passes/pass-convert-mlp-to-parallel-tile.mlir
@@ -80,23 +80,23 @@ module {
//CHECK: %[[temp0:.*]] = call @xsmm_fused_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]], %[[c0_i64]], %[[c5_i64]], %[[c4_i64]], %[[c1_i64]])
//CHECK: omp.parallel {
//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) collapse(2) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.parallel {
//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) collapse(2) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.parallel {
//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) collapse(2) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
58 changes: 0 additions & 58 deletions test/Passes/pass-matmul-blocking-default.mlir
@@ -29,64 +29,6 @@ func.func @block_linalg_matmul(

// -----

func.func @block_linalg_matmul_transpose_a(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-> tensor<128x128xf32> {
%0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_linalg_matmul_transpose_a(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = linalg.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = linalg.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK2:.+]] = linalg.pack %[[ARG2]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%[[PACK2]] : tensor<4x4x32x32xf32>)
// CHECK: %[[OUT:.+]] = linalg.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[ARG2]] : tensor<4x4x32x32xf32> -> tensor<128x128xf32>
// CHECK: return %[[OUT]] : tensor<128x128xf32>

// -----

func.func @block_linalg_matmul_transpose_b(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-> tensor<128x128xf32> {
%0 = linalg.matmul_transpose_b ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_linalg_matmul_transpose_b(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = linalg.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = linalg.pack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK2:.+]] = linalg.pack %[[ARG2]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%[[PACK2]] : tensor<4x4x32x32xf32>)
// CHECK: %[[OUT:.+]] = linalg.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[ARG2]] : tensor<4x4x32x32xf32> -> tensor<128x128xf32>
// CHECK: return %[[OUT]] : tensor<128x128xf32>

// -----

func.func @block_linalg_matmul_dynamic(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-> tensor<?x?xf32> {