@@ -664,17 +664,15 @@ def mlir_converter_permute():
664664 @wave .wave (permute_constraints )
665665 def permute_kernel (
666666 a : Memory [M , N , ADDRESS_SPACE_A , tkl .f16 ],
667- c : Memory [N , M , ADDRESS_SPACE_C , tkl .f32 ],
667+ c : Memory [N , M , ADDRESS_SPACE_C , tkl .f16 ],
668668 ):
669669 # Load values from memory into registers
670- a_reg = wave .read (a )
670+ a_reg = wave .read (a , elements_per_thread = 1 )
671671
672672 # Permute dimensions from [M, N] to [N, M]
673673 permuted = wave .permute (a_reg , target_shape = [N , M ])
674674
675- # Cast and write results back to memory
676- permuted_f32 = wave .cast (permuted , tkl .f32 )
677- wave .write (permuted_f32 , c )
675+ wave .write (permuted , c , elements_per_thread = 1 )
678676
679677 # Set parameters for compilation
680678 subs = {
@@ -713,20 +711,13 @@ def permute_kernel(
713711 print (mlir_output )
714712
715713 # CHECK-LABEL: mlir_converter_permute
716- # CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @M] of f32 , <global>>)
714+ # CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @M] of f16 , <global>>)
717715
718716 # CHECK: %[[READ:.*]] = wave.read %[[ARG0]]
719717 # CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>
720718
721719 # CHECK: %[[PERMUTE:.*]] = wave.permute %[[READ]]
722720 # CHECK-SAME: !wave.tensor<[@M, @N] of f16, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
723721
724- # CHECK: %[[CAST:.*]] = wave.cast %[[PERMUTE]]
725- # CHECK-SAME: : !wave.tensor<[@N, @M] of f16, <register>> to !wave.tensor<[@N, @M] of f32, <register>>
726-
727- # The permuted write has non-contiguous access, so it gets partitioned into
728- # multiple extract_slice + write pairs. Check that the first one has correct types.
729- # CHECK: %[[SLICE:.*]] = wave.extract_slice %[[CAST]]
730- # CHECK-SAME: (!wave.tensor<[@N, @M] of f32, <register>>) -> !wave.tensor<[@N, @M] of f32, <register>>
731- # CHECK: wave.write %[[SLICE]], %[[ARG1]]
732- # CHECK-SAME: !wave.tensor<[@N, @M] of f32, <register>>, !wave.tensor<[@N, @M] of f32, <global>>
722+ # CHECK: wave.write %[[PERMUTE]], %[[ARG1]]
723+ # CHECK-SAME: !wave.tensor<[@N, @M] of f16, <register>>, !wave.tensor<[@N, @M] of f16, <global>>
0 commit comments