@@ -797,3 +797,123 @@ tt.func @pipeline_fp64_with_async_copy_gfx950(
    tt.return %loop : tensor<128x128xf64, #C>
  }
}
+
+ // -----
+
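+ // Pipelines a scaled dot whose B operand is loaded from shared memory via
+ // amdgpu.local_load_packed_tranposed; the packed-transposed load is expected
+ // to stay paired with its ttg.local_load in both the main loop and the
+ // epilogue.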
+ // COMMON-LABEL: pipelining_local_load_packed_transposed
+
+ // Prologue
+ // COMMON: ttg.local_alloc
+ // COMMON: ttg.local_alloc
+ // ASYNC: ttg.async_copy_global_to_local
+ // SYNC: tt.load
+ // COMMON: tt.load
+ // SYNC: ttg.local_store
+ // COMMON: ttg.local_store
+
+ // Main loop
+ // COMMON: scf.for
+ // COMMON: ttg.local_load
+ // COMMON: amdgpu.local_load_packed_tranposed
+ // COMMON: tt.dot_scaled
+ // COMMON: scf.yield
+
+ // Epilogue
+ // COMMON: ttg.local_load
+ // COMMON: amdgpu.local_load_packed_tranposed
+ // COMMON: scf.if
+ // COMMON: tt.dot_scaled
+ // COMMON-COUNT-2: scf.yield
+ // COMMON-COUNT-2: ttg.local_dealloc
+
+ #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+ #blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+ #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+ #mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+ #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+ #smem = #ttg.shared_memory
+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+   tt.func public @pipelining_local_load_packed_transposed(%a_ptr: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %output_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_scale: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bn: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+     %cst = arith.constant dense<128> : tensor<128x128xi32, #blocked>
+     %cst_0 = arith.constant dense<128> : tensor<128x64xi32, #blocked1>
+     %c0_i32 = arith.constant 0 : i32
+     %c1_i32 = arith.constant 1 : i32
+     %c127_i32 = arith.constant 127 : i32
+     %c128_i32 = arith.constant 128 : i32
+     %c2_i32 = arith.constant 2 : i32
+     %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+     %0 = tt.get_program_id x : i32
+     %1 = arith.addi %M, %c127_i32 : i32
+     %2 = arith.divsi %1, %c128_i32 : i32
+     %3 = arith.remsi %0, %2 : i32
+     %4 = arith.divsi %0, %2 : i32
+     %5 = arith.muli %3, %c128_i32 : i32
+     %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+     %7 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+     %8 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+     %9 = tt.splat %5 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+     %10 = tt.splat %5 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+     %11 = arith.addi %9, %6 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+     %12 = arith.addi %10, %7 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+     %13 = arith.muli %4, %c128_i32 : i32
+     %14 = arith.divsi %13, %c2_i32 : i32
+     %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+     %16 = tt.splat %14 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+     %17 = arith.addi %16, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+     %18 = tt.expand_dims %11 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+     %19 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2>
+     %20 = tt.splat %stride_am : i32 -> tensor<128x1xi32, #blocked>
+     %21 = arith.muli %18, %20 : tensor<128x1xi32, #blocked>
+     %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+     %23 = tt.expand_dims %22 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
+     %24 = tt.broadcast %21 : tensor<128x1xi32, #blocked> -> tensor<128x128xi32, #blocked>
+     %25 = tt.broadcast %23 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked>
+     %26 = arith.addi %24, %25 : tensor<128x128xi32, #blocked>
+     %27 = tt.splat %a_ptr : !tt.ptr<f8E5M2> -> tensor<128x128x!tt.ptr<f8E5M2>, #blocked>
+     %28 = tt.addptr %27, %26 : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x128xi32, #blocked>
+     %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+     %30 = tt.expand_dims %29 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
+     %31 = tt.expand_dims %17 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+     %32 = tt.splat %stride_bn : i32 -> tensor<1x64xi32, #blocked1>
+     %33 = arith.muli %31, %32 : tensor<1x64xi32, #blocked1>
+     %34 = tt.broadcast %30 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+     %35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+     %36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1>
+     %37 = tt.splat %b_ptr : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked1>
+     %38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr<i8>, #blocked1>, tensor<128x64xi32, #blocked1>
+     %39 = arith.addi %K, %c127_i32 : i32
+     %40 = arith.divsi %39, %c128_i32 : i32
+     %accumulator:3 = scf.for %accumulator_2 = %c0_i32 to %40 step %c1_i32 iter_args(%arg11 = %cst_1, %arg12 = %28, %arg13 = %38) -> (tensor<128x128xf32, #mma>, tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x64x!tt.ptr<i8>, #blocked1>) : i32 {
+       %60 = tt.load %arg12 : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>
+       %61 = tt.load %arg13 : tensor<128x64x!tt.ptr<i8>, #blocked1>
+       %62 = ttg.convert_layout %60 : tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+       %63 = ttg.local_alloc %61 : (tensor<128x64xi8, #blocked1>) -> !ttg.memdesc<128x64xi8, #shared, #smem>
+       %64 = amdgpu.local_load_packed_tranposed %63 : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+       %65 = tt.dot_scaled %62, %64, %arg11 lhs = e5m2 rhs = e2m1 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<128x128xf32, #mma>
+       %66 = tt.addptr %arg12, %cst : tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x128xi32, #blocked>
+       %67 = tt.addptr %arg13, %cst_0 : tensor<128x64x!tt.ptr<i8>, #blocked1>, tensor<128x64xi32, #blocked1>
+       scf.yield %65, %66, %67 : tensor<128x128xf32, #mma>, tensor<128x128x!tt.ptr<f8E5M2>, #blocked>, tensor<128x64x!tt.ptr<i8>, #blocked1>
+     } {tt.num_stages = 2 : i32}
+     %41 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+     %42 = arith.addi %41, %8 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+     %43 = tt.splat %stride_cm : i32 -> tensor<128x1xi32, #blocked2>
+     %44 = arith.muli %43, %19 : tensor<128x1xi32, #blocked2>
+     %45 = tt.splat %output_ptr : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked2>
+     %46 = tt.addptr %45, %44 : tensor<128x1x!tt.ptr<f32>, #blocked2>, tensor<128x1xi32, #blocked2>
+     %47 = tt.expand_dims %42 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi32, #blocked2>
+     %48 = tt.broadcast %46 : tensor<128x1x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #blocked2>
+     %49 = tt.broadcast %47 : tensor<1x128xi32, #blocked2> -> tensor<128x128xi32, #blocked2>
+     %50 = tt.addptr %48, %49 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi32, #blocked2>
+     %51 = tt.splat %M : i32 -> tensor<128x1xi32, #blocked2>
+     %52 = arith.cmpi slt, %19, %51 : tensor<128x1xi32, #blocked2>
+     %53 = tt.splat %N : i32 -> tensor<1x128xi32, #blocked2>
+     %54 = arith.cmpi slt, %47, %53 : tensor<1x128xi32, #blocked2>
+     %55 = tt.broadcast %52 : tensor<128x1xi1, #blocked2> -> tensor<128x128xi1, #blocked2>
+     %56 = tt.broadcast %54 : tensor<1x128xi1, #blocked2> -> tensor<128x128xi1, #blocked2>
+     %57 = arith.andi %55, %56 : tensor<128x128xi1, #blocked2>
+     %58 = ttg.convert_layout %50 : tensor<128x128x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #mma>
+     %59 = ttg.convert_layout %57 : tensor<128x128xi1, #blocked2> -> tensor<128x128xi1, #mma>
+     tt.store %58, %accumulator#0, %59 : tensor<128x128x!tt.ptr<f32>, #mma>
+     tt.return
+   }
+ }