@@ -2297,3 +2297,112 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return %3 : tensor<128x256xf32, #blocked>
   }
 }
+
+
+// -----
+
+// COM: Check that the dpas layout can be propagated from the dot op to the atomic_rmw op.
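+// COM: (In the test below, only the dot operands and accumulator are converted to the
+// COM: dpas (#mma) layout, while the atomic_rmw operands stay in #blocked2; the pass is
+// COM: presumably expected to propagate the dpas layout forward into the atomic_rmw so
+// COM: that every triton_gpu.convert_layout in the function folds away.)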
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [32], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 32], order = [0, 1]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @propagate_mma_to_atomic_rmw
+  tt.func public @propagate_mma_to_atomic_rmw(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c4096_i32 = arith.constant 4096 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %cst = arith.constant dense<4096> : tensor<256xi32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
+    // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
+      %47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
+      %48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
+      // CHECK-NOT: triton_gpu.convert_layout
+      %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
+      %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
+      // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
+      // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+      // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+      %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
+      %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
+      scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
+    }
+    %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
+    %32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
+    %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
+    %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
+    %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
+    %42 = tt.broadcast %41 : tensor<1x256xi1, #blocked2> -> tensor<256x256xi1, #blocked2>
+    // CHECK: %[[VAL_5:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %46 = tt.atomic_rmw fadd, acq_rel, gpu, %32, %15#0, %42 : (tensor<256x256x!tt.ptr<f32>, #blocked2>, tensor<256x256xf32, #blocked2>, tensor<256x256xi1, #blocked2>) -> tensor<256x256xf32, #blocked2>
+    tt.return
+  }
+}
+
+
+// -----
+
+// COM: Check that a bare atomic_rmw op with a blocked layout can still be rewritten to the dpas layout.
+// COM: The blocked layout must not back-propagate and overwrite the dpas layout.
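+// COM: (Here the def-use chain is already in the dpas (#mma) layout and only the
+// COM: atomic_rmw operands are converted to #blocked; the expectation, as the CHECK
+// COM: lines below suggest, is that forward propagation rewrites the atomic_rmw to
+// COM: the dpas layout rather than letting the blocked layout win, so the
+// COM: convert_layout ops should all disappear.)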
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @bare_atomic_with_blocked_layout
+  tt.func public @bare_atomic_with_blocked_layout(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %cst_0 = arith.constant dense<3072> : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c4096_i32 = arith.constant 4096 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
+      %41 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+      %42 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %43 = tt.load %arg5 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+      %44 = tt.load %arg6 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %45 = tt.dot %43, %44, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      scf.yield %45, %41, %42 : tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    }
+    %18 = tt.splat %0 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %28 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #mma>
+    %30 = arith.cmpi slt, %18, %cst_0 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi1, #mma>
+    %34 = tt.broadcast %31 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %37 = triton_gpu.convert_layout %28 : tensor<256x256x!tt.ptr<f32>, #mma> -> tensor<256x256x!tt.ptr<f32>, #blocked>
+    %38 = triton_gpu.convert_layout %15#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+    %39 = triton_gpu.convert_layout %34 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked>
+    // CHECK: %[[VAL_0:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %40 = tt.atomic_rmw fadd, acq_rel, gpu, %37, %38, %39 : (tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xf32, #blocked>, tensor<256x256xi1, #blocked>) -> tensor<256x256xf32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %41 = triton_gpu.convert_layout %40 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+    tt.return
+  }
+}