@@ -2243,3 +2243,76 @@ func.func @test_firstprivate_map(%arg0: memref<10xf32>) {
22432243// CHECK-NEXT: acc.yield
22442244// CHECK-NEXT: }
22452245// CHECK-NEXT: return
2246+
2247+ // -----
2248+ // Test case: an acc.kernel_environment region that wraps a gpu.launch and receives the results of acc.copyin / acc.create as its dataOperands.
2249+ func.func @test_kernel_environment (%arg0: memref <1024 xf32 >, %arg1: memref <1024 xf32 >) {
2250+ %c1 = arith.constant 1 : index // extent used for all three grid dimensions and for block y/z
2251+ %c1024 = arith.constant 1024 : index // block x-extent; equal to the memref length (1024)
2252+
2253+ // Data clause ops: acc.copyin for %arg0 and acc.create for %arg1; both results become the kernel environment's dataOperands.
2254+ %copyin = acc.copyin varPtr (%arg0 : memref <1024 xf32 >) -> memref <1024 xf32 >
2255+ %create = acc.create varPtr (%arg1 : memref <1024 xf32 >) -> memref <1024 xf32 >
2256+
2257+ // Kernel environment wraps gpu.launch and captures the data mapping; the launch body refers to %copyin and %create directly.
2258+ acc.kernel_environment dataOperands (%copyin , %create : memref <1024 xf32 >, memref <1024 xf32 >) {
2259+ gpu.launch blocks (%bx , %by , %bz ) in (%grid_x = %c1 , %grid_y = %c1 , %grid_z = %c1 )
2260+ threads (%tx , %ty , %tz ) in (%block_x = %c1024 , %block_y = %c1 , %block_z = %c1 ) {
2261+ // Kernel body uses the mapped data: load element %tx from %copyin, multiply it by itself, store the product to %create at the same index.
2262+ %val = memref.load %copyin [%tx ] : memref <1024 xf32 >
2263+ %result = arith.mulf %val , %val : f32
2264+ memref.store %result , %create [%tx ] : memref <1024 xf32 >
2265+ gpu.terminator
2266+ }
2267+ }
2268+
2269+ // Copy results back to host (acc.copyout from %create to %arg1) and deallocate device memory (acc.delete of %copyin).
2270+ acc.copyout accPtr (%create : memref <1024 xf32 >) to varPtr (%arg1 : memref <1024 xf32 >)
2271+ acc.delete accPtr (%copyin : memref <1024 xf32 >)
2272+
2273+ return
2274+ }
2275+ // NOTE(review): the label pattern below is also a textual prefix of the async variant's function name later in this file; matching still works in file order, but anchoring it with a trailing "(" would make it unambiguous - confirm with test owners.
2276+ // CHECK-LABEL: func @test_kernel_environment
2277+ // CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32>
2278+ // CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32>
2279+ // CHECK: acc.kernel_environment dataOperands(%[[COPYIN]], %[[CREATE]] : memref<1024xf32>, memref<1024xf32>) {
2280+ // CHECK: gpu.launch
2281+ // CHECK: memref.load %[[COPYIN]]
2282+ // CHECK: memref.store %{{.*}}, %[[CREATE]]
2283+ // CHECK: }
2284+ // CHECK: }
2285+ // CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) to varPtr(%{{.*}} : memref<1024xf32>)
2286+ // CHECK: acc.delete accPtr(%[[COPYIN]] : memref<1024xf32>)
2287+
2288+ // -----
2289+ // Test case: acc.kernel_environment carrying an async operand; the same i32 value is also passed to acc.create and acc.copyout.
2290+ func.func @test_kernel_environment_with_async (%arg0: memref <1024 xf32 >) {
2291+ %c1 = arith.constant 1 : index
2292+ %c1024 = arith.constant 1024 : index
2293+ %async_val = arith.constant 1 : i32 // async operand shared by acc.create, acc.kernel_environment and acc.copyout
2294+ // Only a create clause is used in this test; its result is the single dataOperand of the kernel environment.
2295+ %create = acc.create varPtr (%arg0 : memref <1024 xf32 >) async (%async_val : i32 ) -> memref <1024 xf32 >
2296+
2297+ // Kernel environment with async clause: async(%async_val : i32) follows the dataOperands list.
2298+ acc.kernel_environment dataOperands (%create : memref <1024 xf32 >) async (%async_val : i32 ) {
2299+ gpu.launch blocks (%bx , %by , %bz ) in (%grid_x = %c1 , %grid_y = %c1 , %grid_z = %c1 )
2300+ threads (%tx , %ty , %tz ) in (%block_x = %c1024 , %block_y = %c1 , %block_z = %c1 ) {
2301+ %f0 = arith.constant 0.0 : f32
2302+ memref.store %f0 , %create [%tx ] : memref <1024 xf32 > // write 0.0 at index %tx
2303+ gpu.terminator
2304+ }
2305+ }
2306+ // acc.copyout takes the same async operand and targets %arg0, the variable the create clause was built from.
2307+ acc.copyout accPtr (%create : memref <1024 xf32 >) async (%async_val : i32 ) to varPtr (%arg0 : memref <1024 xf32 >)
2308+
2309+ return
2310+ }
2311+
2312+ // CHECK-LABEL: func @test_kernel_environment_with_async
2313+ // CHECK: %[[ASYNC:.*]] = arith.constant 1 : i32
2314+ // CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) async(%[[ASYNC]] : i32) -> memref<1024xf32>
2315+ // CHECK: acc.kernel_environment dataOperands(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32)
2316+ // CHECK: gpu.launch
2317+ // CHECK: memref.store %{{.*}}, %[[CREATE]]
2318+ // CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) to varPtr(%{{.*}} : memref<1024xf32>)
0 commit comments