@@ -250,8 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
     xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
     gpu.return
   }
-}
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
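+  // Blocking with inst_data = [16] should split each 32-lane scattered access into two 16-lane operations, as checked by the CHECK-COUNT-2 lines above.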
+  gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    %delta = arith.constant dense<[
+      32, 32, 32, 32, 32, 32, 32, 32,
+      32, 32, 32, 32, 32, 32, 32, 64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+    %c17 = arith.constant 17 : index
+    %mask = vector.create_mask %c17: vector<32xi1>
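+    // create_mask with %c17 enables only the first 17 of the 32 lanes; the remaining lanes are masked off.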
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask:
+                 vector<32xf32>,
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+                 vector<32xi1>
+
+    gpu.return
+  }
+
+}
+
+// -----
+
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+  // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
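+  // With chunk_size=4 and inst_data = [16, 2], each 32x4 scattered access should be blocked into four 16x2 (chunk_size = 2) operations, as checked above.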
+  gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+    %delta = arith.constant dense<[
+      32, 32, 32, 32, 32, 32, 32, 32,
+      32, 32, 32, 32, 32, 32, 32, 64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+    %c17 = arith.constant 17 : index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
+                 vector<4x32xf32>,
+                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
+                 vector<32xi1>
+
+    gpu.return
+  }
+}
+