README.md (1 addition, 1 deletion)
@@ -4,7 +4,7 @@

If you are not using a local iree build, install the iree pip packages:
```
-pip install --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade
+pip install --pre --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade
```

Create a python environment and install the requirements for the project:
iree_kernel_benchmark/gemmbench/gemm_utils.py (39 additions, 61 deletions)
@@ -19,6 +19,7 @@
import os
import traceback
from iree.compiler import ir
+from iree.compiler.dialects import arith, func, linalg, tensor


@dataclass
@@ -89,75 +90,52 @@ def generate_mlir(config: GemmConfig):
    K = config.K
    M = config.M
    N = config.N
+    tA = config.tA
+    tB = config.tB

    with ir.Location.name(config.get_name()):
        operand_element_type = _convert_dtype_to_mlir(config.operand_element_type)
        acc_element_type = _convert_dtype_to_mlir(config.accumulator_element_type)
        result_element_type = _convert_dtype_to_mlir(config.result_element_type)
        is_integer = isinstance(operand_element_type, ir.IntegerType)
        literal_zero = ir.IntegerAttr.get(acc_element_type, 0) if is_integer else ir.FloatAttr.get(acc_element_type, 0.0)
-        K_M_operand_tensor_type = ir.RankedTensorType.get([K, M], operand_element_type)
-        M_K_operand_tensor_type = ir.RankedTensorType.get([M, K], operand_element_type)
-        K_N_operand_tensor_type = ir.RankedTensorType.get([K, N], operand_element_type)
-        N_K_operand_tensor_type = ir.RankedTensorType.get([N, K], operand_element_type)
-        M_N_acc_tensor_type = ir.RankedTensorType.get([M, N], acc_element_type)
-        M_N_result_tensor_type = ir.RankedTensorType.get([M, N], result_element_type)
-
-        trunc_op = "arith.trunci" if is_integer else "arith.truncf"

-        tA = config.tA
-        tB = config.tB
-        mlir_template_matmul_transpose_a = f"""
-module {{
-  func.func @main(%arg0: {K_M_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
-    %cst = arith.constant {literal_zero}
-    %0 = tensor.empty() : {M_N_acc_tensor_type}
-    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
-    %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : {K_M_operand_tensor_type}, {K_N_operand_tensor_type})
-      outs(%1 : {M_N_acc_tensor_type})
-      -> {M_N_acc_tensor_type}
-"""
-
-        mlir_template_matmul_transpose_b = f"""
-module {{
-  func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {N_K_operand_tensor_type}) -> {M_N_result_tensor_type} {{
-    %cst = arith.constant {literal_zero}
-    %0 = tensor.empty() : {M_N_acc_tensor_type}
-    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
-    %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {N_K_operand_tensor_type})
-      outs(%1 : {M_N_acc_tensor_type})
-      -> {M_N_acc_tensor_type}
-"""
-
-        mlir_template_matmul_normal = f"""
-module {{
-  func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
-    %cst = arith.constant {literal_zero}
-    %0 = tensor.empty() : {M_N_acc_tensor_type}
-    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
-    %2 = linalg.matmul ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {K_N_operand_tensor_type})
-      outs(%1 : {M_N_acc_tensor_type})
-      -> {M_N_acc_tensor_type}
-"""
-        mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal
-
-        mlir_template_return_truncated = f"""
-    %3 = {trunc_op} %2 : {M_N_acc_tensor_type} to {M_N_result_tensor_type}
-    return %3 : {M_N_result_tensor_type}
-  }}
-}}
-"""
-
-        mlir_template_return_untruncated = f"""
-    return %2 : {M_N_result_tensor_type}
-  }}
-}}
-"""
-
-        mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated
-
-        return mlir_template_matmul + mlir_template_return

+        # Transpose A
+        if tA == "T":
+            arg0_type = ir.RankedTensorType.get([K, M], operand_element_type)
+            arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        # Transpose B
+        elif tB == "T":
+            arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+            arg1_type = ir.RankedTensorType.get([N, K], operand_element_type)
+        # "Normal" path (can't transpose both)
+        else:
+            assert tA == "N" and tB == "N"
+            arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+            arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        result_type = ir.RankedTensorType.get([M, N], result_element_type)
+
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            @func.FuncOp.from_py_func(arg0_type, arg1_type)
+            def main(arg0, arg1):
+                zero_element = arith.constant(value = literal_zero, result = acc_element_type)
+                empty_tensor = tensor.empty(element_type = acc_element_type, sizes = [M, N])
+                filled_tensor = linalg.fill(zero_element, outs = [empty_tensor])
+
+                if tA == "T":
+                    acc = linalg.matmul_transpose_a(arg0, arg1, outs = [filled_tensor])
+                elif tB == "T":
+                    acc = linalg.matmul_transpose_b(arg0, arg1, outs = [filled_tensor])
+                else:
+                    acc = linalg.matmul(arg0, arg1, outs = [filled_tensor])
+
+                if acc_element_type == result_element_type:
+                    return acc
+                if is_integer:
+                    return arith.trunci(result_type, acc)
+                return arith.truncf(result_type, acc)
+        return f"{module}"

@dataclass
class TkTunedConfig:
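For reference, a minimal sketch of driving the new builder-based generate_mlir. The field names mirror the diff, but the GemmConfig constructor signature and the dtype spellings ("f16", "f32") are assumptions:

```python
# Sketch only: GemmConfig's exact constructor and the dtype spellings are
# assumptions; adjust to the real dataclass in gemm_utils.py.
from iree.compiler import ir
from iree_kernel_benchmark.gemmbench.gemm_utils import GemmConfig, generate_mlir

config = GemmConfig(
    M=2048, N=2048, K=1024, tA="N", tB="N",
    operand_element_type="f16",
    accumulator_element_type="f32",
    result_element_type="f16",
)

# generate_mlir opens ir.Location.name(...) internally, so an MLIR context
# must be active on the context-manager stack.
with ir.Context():
    print(generate_mlir(config))
```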
requirements.txt (1 addition)
@@ -3,3 +3,4 @@ tqdm
matplotlib
torch>=2.3.0
pytest>=8.3.5
+PyYAML>=6.0.2 # Required by iree.compiler.dialects.linalg
tests/test_gemmbench_mlir_gen.py (8 additions, 20 deletions)
@@ -34,9 +34,7 @@ def test_n_t_f16_f32_f16():
"%cst = arith.constant 0.000000e+00 : f32",
"%0 = tensor.empty() : tensor<512x4096xf32>",
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
-"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>)",
-"outs(%1 : tensor<512x4096xf32>)",
-"-> tensor<512x4096xf32>",
+"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
Member: Is there some constructor that doesn't produce these cast attributes? I think it's fine as-is and being explicit doesn't hurt, but it does look odd since nothing else in IREE produces these.

Contributor Author: I'll investigate it; I'm curious myself as to what this means.

Contributor Author: I don't think it can be done with the constructor currently in use here: it takes a TypeFn for the cast argument, which can only be cast_signed or cast_unsigned; there are no other options (I tried None and creating an empty TypeFn; neither was accepted). I tried doing this instead:

acc = linalg.MatmulTransposeAOp(inputs = [arg0, arg1], outputs = [filled_tensor], result_tensors = [acc_type])

But it seems like that produces an invalid operation; the IR prints weirdly after I do that:

%3 = "linalg.matmul_transpose_a"(%arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> ({
    }) {linalg.memoized_indexing_maps = [#map, #map1, #map2]} : (tensor<5120x32000xbf16>, tensor<5120x1xbf16>, tensor<32000x1xf32>) -> tensor<32000x1xf32>

And I get this error if I call .verify():

E           iree.compiler._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
E           error: "gemm_32000_1_5120_f16_f32_tA": 'linalg.matmul_transpose_a' op expects to have 1 region with 1 block
E            note: "gemm_32000_1_5120_f16_f32_tA": see current operation:
E             %3 = "linalg.matmul_transpose_a"(%arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> ({
E             }) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]} : (tensor<5120x32000xf16>, tensor<5120x1xf16>, tensor<32000x1xf32>) -> tensor<32000x1xf32>

There's probably some way to repair it, but I'm not sure it's worth the effort. Maybe I can ask for a second opinion from someone who knows linalg better?
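For reference, a self-contained sketch that reproduces the failure described above, under the same bindings; the shapes are arbitrary and the constructor keywords mirror the snippet in this comment:

```python
# Repro sketch for the missing-region verifier error quoted above; assumes
# the iree.compiler Python bindings used in this PR.
from iree.compiler import ir
from iree.compiler.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    f16 = ir.F16Type.get()
    f32 = ir.F32Type.get()
    lhs_type = ir.RankedTensorType.get([8, 4], f16)  # K x M
    rhs_type = ir.RankedTensorType.get([8, 6], f16)  # K x N
    acc_type = ir.RankedTensorType.get([4, 6], f32)  # M x N

    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        @func.FuncOp.from_py_func(lhs_type, rhs_type, acc_type)
        def repro(lhs, rhs, acc):
            # Direct OpView construction attaches no region, so the verifier
            # rejects the op ("expects to have 1 region with 1 block").
            op = linalg.MatmulTransposeAOp(
                inputs=[lhs, rhs], outputs=[acc], result_tensors=[acc_type]
            )
            return op.result

    module.operation.verify()  # raises MLIRError as quoted above
```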

Member: Yes, ask on the IREE Discord.

Member: Should we file an issue against MLIR (upstream)?

Contributor Author: I don't feel strongly either way on that; maybe @rkayaith or @makslevental have an opinion there?

Member: I don't think the Python bindings are doing anything wrong, necessarily; the C++ side defaults to cast_signed as well:

// LinalgNamedStructuredOps.yamlgen.td

def MatmulTransposeBOp : LinalgStructuredBase_Op<"matmul_transpose_b", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
    ...
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs,
      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
    );
    ...

But unfortunately, using DefaultValuedOptionalAttr means the attribute is only elided from the asm format when it is missing entirely. IMO using DefaultValuedAttr instead (no Optional), plus updating the custom asm format to elide the default value, would be a good solution here, but I don't know if there'd be any repercussions.
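For reference, a small sketch of observing that default from Python, under the same bindings as the diff; it assumes linalg.matmul returns the result Value (as the new generate_mlir code does), whose owner is the producing op:

```python
# Sketch: the high-level builder always materializes cast = cast_signed, and
# because the attribute is DefaultValuedOptionalAttr it is not elided when
# present, which is why it appears in the asm checked by the tests below.
from iree.compiler import ir
from iree.compiler.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    f32 = ir.F32Type.get()
    t = ir.RankedTensorType.get([4, 4], f32)
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        @func.FuncOp.from_py_func(t, t, t)
        def show_cast(a, b, acc):
            result = linalg.matmul(a, b, outs=[acc])
            # Prints: #linalg.type_fn<cast_signed>
            print(result.owner.attributes["cast"])
            return result
```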

Contributor Author: I created llvm/llvm-project#132961 now. I am not very confident I have described the issue well, so if anyone thinks they can add something by commenting there, I would appreciate it.

Member: Can you cc the people who made this change? Otherwise they may never see this issue.

"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>",
"return %3 : tensor<512x4096xf16>",
],
@@ -64,9 +62,7 @@ def test_n_t_bf16_f32_bf16():
"%cst = arith.constant 0.000000e+00 : f32",
"%0 = tensor.empty() : tensor<2x1280xf32>",
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
-"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>)",
-"outs(%1 : tensor<2x1280xf32>)",
-"-> tensor<2x1280xf32>",
+"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
"%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>",
"return %3 : tensor<2x1280xbf16>",
],
@@ -94,9 +90,7 @@ def test_t_n_f16_f32_f16():
"%cst = arith.constant 0.000000e+00 : f32",
"%0 = tensor.empty() : tensor<32000x1xf32>",
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
-"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>)",
-"outs(%1 : tensor<32000x1xf32>)",
-"-> tensor<32000x1xf32>",
+"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>",
"return %3 : tensor<32000x1xf16>",
],
@@ -124,9 +118,7 @@ def test_t_n_bf16_f32_bf16():
"%cst = arith.constant 0.000000e+00 : f32",
"%0 = tensor.empty() : tensor<32000x1xf32>",
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
-"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>)",
-"outs(%1 : tensor<32000x1xf32>)",
-"-> tensor<32000x1xf32>",
+"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>",
"return %3 : tensor<32000x1xbf16>",
],
@@ -154,9 +146,7 @@ def test_n_n_f16_f32_f16():
"%cst = arith.constant 0.000000e+00 : f32",
"%0 = tensor.empty() : tensor<2048x2048xf32>",
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>",
-"%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>)",
-"outs(%1 : tensor<2048x2048xf32>)",
-"-> tensor<2048x2048xf32>",
+"%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>) outs(%1 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>",
"%3 = arith.truncf %2 : tensor<2048x2048xf32> to tensor<2048x2048xf16>",
"return %3 : tensor<2048x2048xf16>",
],
@@ -181,12 +171,10 @@ def test_n_t_i8_i32_i8():
[
"module {",
"func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {",
-"%cst = arith.constant 0 : i32",
+"%c0_i32 = arith.constant 0 : i32",
"%0 = tensor.empty() : tensor<128x128xi32>",
-"%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>",
-"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>)",
-"outs(%1 : tensor<128x128xi32>)",
-"-> tensor<128x128xi32>",
+"%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>",
+"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>",
"%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>",
"return %3 : tensor<128x128xi8>",
],