Skip to content

Commit 3116fcf

Browse files
committed
fix mlir_exporter
Signed-off-by: Tim Gymnich <tim@gymni.ch>
1 parent 3b5591a commit 3116fcf

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

lit_tests/kernel/wave/mlir_converter.py

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -664,17 +664,15 @@ def mlir_converter_permute():
664664
@wave.wave(permute_constraints)
665665
def permute_kernel(
666666
a: Memory[M, N, ADDRESS_SPACE_A, tkl.f16],
667-
c: Memory[N, M, ADDRESS_SPACE_C, tkl.f32],
667+
c: Memory[N, M, ADDRESS_SPACE_C, tkl.f16],
668668
):
669669
# Load values from memory into registers
670-
a_reg = wave.read(a)
670+
a_reg = wave.read(a, elements_per_thread=1)
671671

672672
# Permute dimensions from [M, N] to [N, M]
673673
permuted = wave.permute(a_reg, target_shape=[N, M])
674674

675-
# Cast and write results back to memory
676-
permuted_f32 = wave.cast(permuted, tkl.f32)
677-
wave.write(permuted_f32, c)
675+
wave.write(permuted, c, elements_per_thread=1)
678676

679677
# Set parameters for compilation
680678
subs = {
@@ -713,20 +711,13 @@ def permute_kernel(
713711
print(mlir_output)
714712

715713
# CHECK-LABEL: mlir_converter_permute
716-
# CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @M] of f32, <global>>)
714+
# CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @M] of f16, <global>>)
717715

718716
# CHECK: %[[READ:.*]] = wave.read %[[ARG0]]
719717
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>
720718

721719
# CHECK: %[[PERMUTE:.*]] = wave.permute %[[READ]]
722720
# CHECK-SAME: !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
723721

724-
# CHECK: %[[CAST:.*]] = wave.cast %[[PERMUTE]]
725-
# CHECK-SAME: : !wave.tensor<[@N, @M] of f16, <register>> to !wave.tensor<[@N, @M] of f32, <register>>
726-
727-
# The permuted write has non-contiguous access, so it gets partitioned into
728-
# multiple extract_slice + write pairs. Check that the first one has correct types.
729-
# CHECK: %[[SLICE:.*]] = wave.extract_slice %[[CAST]]
730-
# CHECK-SAME: (!wave.tensor<[@N, @M] of f32, <register>>) -> !wave.tensor<[@N, @M] of f32, <register>>
731-
# CHECK: wave.write %[[SLICE]], %[[ARG1]]
732-
# CHECK-SAME: !wave.tensor<[@N, @M] of f32, <register>>, !wave.tensor<[@N, @M] of f32, <global>>
722+
# CHECK: wave.write %[[PERMUTE]], %[[ARG1]]
723+
# CHECK-SAME: !wave.tensor<[@N, @M] of f16, <register>>, !wave.tensor<[@N, @M] of f16, <global>>

0 commit comments

Comments
 (0)