@@ -377,6 +377,40 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
377377 } {tt.flatten }
378378 tt.return
379379 }
380+ }
380381
382+ // -----
381383
384+ // COM: Ensure prefetch operations aren't generated for 3D loads.
385+ #linear = #ttg.linear <{register = [[0 , 1 , 0 ], [0 , 2 , 0 ], [0 , 4 , 0 ], [0 , 8 , 0 ], [0 , 16 , 0 ], [0 , 0 , 16 ], [0 , 0 , 32 ], [0 , 64 , 0 ]], lane = [[0 , 0 , 1 ], [0 , 0 , 2 ], [0 , 0 , 4 ], [0 , 0 , 8 ]], warp = [[0 , 0 , 0 ], [0 , 0 , 0 ], [0 , 32 , 0 ]], block = []}>
386+ #linear1 = #ttg.linear <{register = [[0 , 0 , 1 ], [0 , 0 , 2 ], [0 , 0 , 4 ], [0 , 0 , 8 ], [0 , 16 , 0 ], [0 , 0 , 16 ], [0 , 0 , 32 ], [0 , 128 , 0 ]], lane = [[0 , 1 , 0 ], [0 , 2 , 0 ], [0 , 4 , 0 ], [0 , 8 , 0 ]], warp = [[0 , 32 , 0 ], [0 , 64 , 0 ], [0 , 0 , 0 ]], block = []}>
387+ #linear2 = #ttg.linear <{register = [[0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ], [16 , 0 ], [0 , 16 ], [0 , 32 ], [128 , 0 ]], lane = [[1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ]], warp = [[32 , 0 ], [64 , 0 ], [0 , 0 ]], block = []}>
388+ #mma = #triton_intel_gpu.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [2 , 4 ], repCluster = [4 , 2 ], A = [32 , 16 ], B = [16 , 32 ], C = [32 , 32 ]}>
389+ module attributes {triton_intel_gpu.support_dpas , triton_intel_gpu.support_sg_2d_block , triton_intel_gpu.target_arch = " spir64" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 8 : i32 , ttg.target = " xpu" , " ttg.threads-per-warp" = 16 : i32 } {
390+ // CHECK-LABEL: batched_gemm_3d_tma_kernel
391+ tt.func public @batched_gemm_3d_tma_kernel (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg3: i32 , %arg6: i32 {tt.divisibility = 16 : i32 }) {
392+ %c1_i32 = arith.constant 1 : i32
393+ %c0_i32 = arith.constant 0 : i32
394+ %c64_i32 = arith.constant 64 : i32
395+ %c1_i64 = arith.constant 1 : i64
396+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x256 xf32 , #mma >
397+ %0 = tt.get_program_id x : i32
398+ %16 = arith.extsi %arg6 : i32 to i64
399+ %24 = arith.extsi %arg3 : i32 to i64
400+ %26:1 = scf.for %arg7 = %c0_i32 to %c64_i32 step %c1_i32 iter_args (%arg8 = %c1_i32 ) -> (i32 ) : i32 {
401+ // CHECK-NOT: prefetch
402+ %27 = arith.cmpi eq , %arg8 , %c1_i32 : i32
403+ %29 = arith.select %27 , %c0_i32 , %arg8 : i32
404+ %33 = tt.make_tensor_ptr %arg0 , [%24 , %24 , %16 ], [%16 , %16 , %c1_i64 ], [%arg8 , %arg8 , %29 ] {order = array<i32 : 1 , 0 >} : <tensor <1 x128 x64 xf16 , #linear >>
405+ %34 = tt.load %33 {triton_intel_gpu.block_io = " row_major" } : !tt.ptr <tensor <1 x128 x64 xf16 , #linear >>
406+ %35 = tt.reshape %34 : tensor <1 x128 x64 xf16 , #linear > -> tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
407+ %36 = tt.make_tensor_ptr %arg1 , [%24 , %16 , %16 ], [%16 , %16 , %c1_i64 ], [%arg8 , %arg8 , %29 ] {order = array<i32 : 1 , 0 >} : <tensor <1 x256 x64 xf16 , #linear1 >>
408+ %37 = tt.load %36 {triton_intel_gpu.block_io = " row_major" } : !tt.ptr <tensor <1 x256 x64 xf16 , #linear1 >>
409+ %38 = tt.reshape %37 : tensor <1 x256 x64 xf16 , #linear1 > -> tensor <256 x64 xf16 , #linear2 >
410+ %39 = tt.trans %38 {order = array<i32 : 1 , 0 >} : tensor <256 x64 xf16 , #linear2 > -> tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
411+ %40 = tt.dot %35 , %39 , %cst : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> * tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x256 xf32 , #mma >
412+ scf.yield %29 : i32
413+ }
414+ tt.return
415+ }
382416}
0 commit comments