@@ -700,3 +700,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return %75#0 : tensor<128x256xf32, #blocked3>
   }
 }
+
+// -----
+
+// Check that we do not get AsyncCopyGlobalToLocal because the vector width would be < 32 bit.
+// The shared memory order getMemoryOrder(#linear1) == [0, 1] differs from the
+// order [1, 0] of the blocked layout, so AsyncCopyGlobalToLocal would have to
+// gather into LDS; we therefore fall back to registers.
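+//
+// Illustrative sketch (not part of the CHECKs below): with the register fallback
+// the scale tensor keeps being loaded by the plain `tt.load` inside the loop,
+//   %7 = tt.load %arg8 : tensor<512x8x!tt.ptr<i8>, #blocked>
+// instead of being rewritten to a `ttg.async_copy_global_to_local` into LDS
+// followed by a `ttg.local_load`.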
+
+// COMMON-LABEL: pipeline_scale_memory_order
+// COMMON-NOT: ttg.async_copy_global_to_local
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [64, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 4], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0], [0, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 4], [128, 0], [256, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[16, 0], [32, 0], [64, 0]], block = []}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pipeline_scale_memory_order(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg4: tensor<128x512x!tt.ptr<f32>, #mma>, %arg5: tensor<512x8x!tt.ptr<i8>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<127> : tensor<128x8xi8, #linear>
+    %cst_0 = arith.constant dense<8> : tensor<512x8xi32, #blocked>
+    %c256_i64 = arith.constant 256 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x512xf32, #mma>
+    %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %1 = arith.extsi %0 : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x8xi64, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<1x8x!tt.ptr<i8>, #blocked>
+    %4 = tt.addptr %3, %2 : tensor<1x8x!tt.ptr<i8>, #blocked>, tensor<1x8xi64, #blocked>
+    %5 = tt.broadcast %4 : tensor<1x8x!tt.ptr<i8>, #blocked> -> tensor<512x8x!tt.ptr<i8>, #blocked>
+    %6:2 = scf.for %arg6 = %c0_i64 to %arg1 step %c256_i64 iter_args(%arg7 = %cst_1, %arg8 = %5) -> (tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>) : i64 {
+      %7 = tt.load %arg8 : tensor<512x8x!tt.ptr<i8>, #blocked>
+      %8 = ttg.convert_layout %7 : tensor<512x8xi8, #blocked> -> tensor<512x8xi8, #linear1>
+      %9 = tt.dot_scaled %arg2 scale %cst, %arg3 scale %8, %arg7 lhs = e4m3 rhs = e2m1 {fastMath = true} : tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x8xi8, #linear> * tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<512x8xi8, #linear1> -> tensor<128x512xf32, #mma>
+      %10 = tt.addptr %arg8, %cst_0 : tensor<512x8x!tt.ptr<i8>, #blocked>, tensor<512x8xi32, #blocked>
+      scf.yield %9, %10 : tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>
+    }
+    tt.store %arg4, %6#0 : tensor<128x512x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}