diff --git a/README.md b/README.md index 2636a64..5b8d3e4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ If you are not using a local iree build, install the iree pip packages: ``` -pip install --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade +pip install --pre --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade ``` Create a python environment and install the requirements for the project: diff --git a/iree_kernel_benchmark/gemmbench/gemm_utils.py b/iree_kernel_benchmark/gemmbench/gemm_utils.py index 8d3efdc..46954a4 100644 --- a/iree_kernel_benchmark/gemmbench/gemm_utils.py +++ b/iree_kernel_benchmark/gemmbench/gemm_utils.py @@ -19,6 +19,7 @@ import os import traceback from iree.compiler import ir +from iree.compiler.dialects import arith, func, linalg, tensor @dataclass @@ -89,6 +90,8 @@ def generate_mlir(config: GemmConfig): K = config.K M = config.M N = config.N + tA = config.tA + tB = config.tB with ir.Location.name(config.get_name()): operand_element_type = _convert_dtype_to_mlir(config.operand_element_type) @@ -96,68 +99,43 @@ def generate_mlir(config: GemmConfig): result_element_type = _convert_dtype_to_mlir(config.result_element_type) is_integer = isinstance(operand_element_type, ir.IntegerType) literal_zero = ir.IntegerAttr.get(acc_element_type, 0) if is_integer else ir.FloatAttr.get(acc_element_type, 0.0) - K_M_operand_tensor_type = ir.RankedTensorType.get([K, M], operand_element_type) - M_K_operand_tensor_type = ir.RankedTensorType.get([M, K], operand_element_type) - K_N_operand_tensor_type = ir.RankedTensorType.get([K, N], operand_element_type) - N_K_operand_tensor_type = ir.RankedTensorType.get([N, K], operand_element_type) - M_N_acc_tensor_type = ir.RankedTensorType.get([M, N], acc_element_type) - M_N_result_tensor_type = ir.RankedTensorType.get([M, N], result_element_type) - - trunc_op = "arith.trunci" if is_integer else "arith.truncf" - - tA = config.tA - tB = config.tB - mlir_template_matmul_transpose_a = f""" -module {{ - func.func @main(%arg0: {K_M_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{ - %cst = arith.constant {literal_zero} - %0 = tensor.empty() : {M_N_acc_tensor_type} - %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type} - %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : {K_M_operand_tensor_type}, {K_N_operand_tensor_type}) - outs(%1 : {M_N_acc_tensor_type}) - -> {M_N_acc_tensor_type} -""" - - mlir_template_matmul_transpose_b = f""" -module {{ - func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {N_K_operand_tensor_type}) -> {M_N_result_tensor_type} {{ - %cst = arith.constant {literal_zero} - %0 = tensor.empty() : {M_N_acc_tensor_type} - %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type} - %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {N_K_operand_tensor_type}) - outs(%1 : {M_N_acc_tensor_type}) - -> {M_N_acc_tensor_type} -""" - - mlir_template_matmul_normal = f""" -module {{ - func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{ - %cst = arith.constant {literal_zero} - %0 = tensor.empty() : {M_N_acc_tensor_type} - %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type} - %2 = linalg.matmul ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {K_N_operand_tensor_type}) - outs(%1 : {M_N_acc_tensor_type}) - -> {M_N_acc_tensor_type} -""" - mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal - - mlir_template_return_truncated = f""" - %3 = {trunc_op} %2 : {M_N_acc_tensor_type} to {M_N_result_tensor_type} - return %3 : {M_N_result_tensor_type} - }} -}} -""" - - mlir_template_return_untruncated = f""" - return %2 : {M_N_result_tensor_type} - }} -}} -""" - - mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated - - return mlir_template_matmul + mlir_template_return + # Transpose A + if tA == "T": + arg0_type = ir.RankedTensorType.get([K, M], operand_element_type) + arg1_type = ir.RankedTensorType.get([K, N], operand_element_type) + # Transpose B + elif tB == "T": + arg0_type = ir.RankedTensorType.get([M, K], operand_element_type) + arg1_type = ir.RankedTensorType.get([N, K], operand_element_type) + # "Normal" path (can't transpose both) + else: + assert tA == "N" and tB == "N" + arg0_type = ir.RankedTensorType.get([M, K], operand_element_type) + arg1_type = ir.RankedTensorType.get([K, N], operand_element_type) + result_type = ir.RankedTensorType.get([M, N], result_element_type) + + module = ir.Module.create() + with ir.InsertionPoint(module.body): + @func.FuncOp.from_py_func(arg0_type, arg1_type) + def main(arg0, arg1): + zero_element = arith.constant(value = literal_zero, result = acc_element_type) + empty_tensor = tensor.empty(element_type = acc_element_type, sizes = [M, N]) + filled_tensor = linalg.fill(zero_element, outs = [empty_tensor]) + + if tA == "T": + acc = linalg.matmul_transpose_a(arg0, arg1, outs = [filled_tensor]) + elif tB == "T": + acc = linalg.matmul_transpose_b(arg0, arg1, outs = [filled_tensor]) + else: + acc = linalg.matmul(arg0, arg1, outs = [filled_tensor]) + + if acc_element_type == result_element_type: + return acc + if is_integer: + return arith.trunci(result_type, acc) + return arith.truncf(result_type, acc) + return f"{module}" @dataclass class TkTunedConfig: diff --git a/requirements.txt b/requirements.txt index d85bd87..bad4398 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ tqdm matplotlib torch>=2.3.0 pytest>=8.3.5 +PyYAML>=6.0.2 # Required by iree.compiler.dialects.linalg diff --git a/tests/test_gemmbench_mlir_gen.py b/tests/test_gemmbench_mlir_gen.py index 1592d64..414ea65 100644 --- a/tests/test_gemmbench_mlir_gen.py +++ b/tests/test_gemmbench_mlir_gen.py @@ -34,9 +34,7 @@ def test_n_t_f16_f32_f16(): "%cst = arith.constant 0.000000e+00 : f32", "%0 = tensor.empty() : tensor<512x4096xf32>", "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>", - "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>)", - "outs(%1 : tensor<512x4096xf32>)", - "-> tensor<512x4096xf32>", + "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>", "%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>", "return %3 : tensor<512x4096xf16>", ], @@ -64,9 +62,7 @@ def test_n_t_bf16_f32_bf16(): "%cst = arith.constant 0.000000e+00 : f32", "%0 = tensor.empty() : tensor<2x1280xf32>", "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>", - "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>)", - "outs(%1 : tensor<2x1280xf32>)", - "-> tensor<2x1280xf32>", + "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>", "%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>", "return %3 : tensor<2x1280xbf16>", ], @@ -94,9 +90,7 @@ def test_t_n_f16_f32_f16(): "%cst = arith.constant 0.000000e+00 : f32", "%0 = tensor.empty() : tensor<32000x1xf32>", "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>", - "%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>)", - "outs(%1 : tensor<32000x1xf32>)", - "-> tensor<32000x1xf32>", + "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>", "%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>", "return %3 : tensor<32000x1xf16>", ], @@ -124,9 +118,7 @@ def test_t_n_bf16_f32_bf16(): "%cst = arith.constant 0.000000e+00 : f32", "%0 = tensor.empty() : tensor<32000x1xf32>", "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>", - "%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>)", - "outs(%1 : tensor<32000x1xf32>)", - "-> tensor<32000x1xf32>", + "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>", "%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>", "return %3 : tensor<32000x1xbf16>", ], @@ -154,9 +146,7 @@ def test_n_n_f16_f32_f16(): "%cst = arith.constant 0.000000e+00 : f32", "%0 = tensor.empty() : tensor<2048x2048xf32>", "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>", - "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>)", - "outs(%1 : tensor<2048x2048xf32>)", - "-> tensor<2048x2048xf32>", + "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>) outs(%1 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>", "%3 = arith.truncf %2 : tensor<2048x2048xf32> to tensor<2048x2048xf16>", "return %3 : tensor<2048x2048xf16>", ], @@ -181,12 +171,10 @@ def test_n_t_i8_i32_i8(): [ "module {", "func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {", - "%cst = arith.constant 0 : i32", + "%c0_i32 = arith.constant 0 : i32", "%0 = tensor.empty() : tensor<128x128xi32>", - "%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>", - "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>)", - "outs(%1 : tensor<128x128xi32>)", - "-> tensor<128x128xi32>", + "%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>", + "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>", "%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>", "return %3 : tensor<128x128xi8>", ],