@@ -1197,3 +1197,60 @@ hal.executable private @matvec_dispatch_0 {
 // CHECK: scf.forall ({{.*}}) = (0, 0) to (32000, 2) step (16, 1)
 // CHECK-COUNT-16: gpu.subgroup_reduce add {{.*}} cluster(size = 64) : (f32) -> f32
 // CHECK-COUNT-16: gpu.subgroup_reduce add {{.*}} cluster(size = 2) : (f32) -> f32
+
+// -----
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable public @matmul_map_scatter {
+  hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @matmul_map_scatter layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_map_scatter() attributes {translation_info = #translation} {
+        %true = arith.constant true
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+        %1 = amdgpu.fat_raw_buffer_cast %0 resetOffset : memref<256x256xf16, #hal.descriptor_type<storage_buffer>> to memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+        %3 = amdgpu.fat_raw_buffer_cast %2 resetOffset : memref<256x256xf16, #hal.descriptor_type<storage_buffer>> to memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>>
+        %4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : memref<2x16x8x4x4x4x4xf32, #hal.descriptor_type<storage_buffer>>
+        %5 = amdgpu.fat_raw_buffer_cast %4 resetOffset : memref<2x16x8x4x4x4x4xf32, #hal.descriptor_type<storage_buffer>> to memref<2x16x8x4x4x4x4xf32, #amdgpu.address_space<fat_raw_buffer>>
+        %6 = iree_codegen.load_from_buffer %1 : memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<256x256xf16>
+        %7 = iree_codegen.load_from_buffer %3 : memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<256x256xf16>
+        %8 = tensor.empty() : tensor<256x256xf32>
+        %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<256x256xf32>) -> tensor<256x256xf32>
+        %10 = linalg.matmul {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 128], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>} ins(%6, %7 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%9 : tensor<256x256xf32>) -> tensor<256x256xf32>
+        %11 = tensor.empty() : tensor<2x16x8x4x4x4x4xf32>
+        %12 = iree_linalg_ext.map_scatter %10 into %11 {
+        ^bb0(%arg0: index, %arg1: index):
+          %13:2 = affine.delinearize_index %arg0 into (2, 128) : index, index
+          %14:2 = affine.delinearize_index %arg1 into (16, 16) : index, index
+          %15:3 = affine.delinearize_index %13#1 into (4, 8, 4) : index, index, index
+          %16:2 = affine.delinearize_index %14#1 into (4, 4) : index, index
+          iree_linalg_ext.yield %13#0, %14#0, %15#1, %16#1, %15#0, %15#2, %16#0, %true : index, index, index, index, index, index, index, i1
+        } : tensor<256x256xf32> into tensor<2x16x8x4x4x4x4xf32> -> tensor<2x16x8x4x4x4x4xf32>
+        iree_codegen.store_to_buffer %12, %5 : tensor<2x16x8x4x4x4x4xf32> into memref<2x16x8x4x4x4x4xf32, #amdgpu.address_space<fat_raw_buffer>>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: func.func @matmul_map_scatter()
+// CHECK: %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan{{.*}} binding(2)
+// CHECK: %[[OUTPUT_BINDING_ALIGNED:.+]] = memref.assume_alignment %[[OUTPUT_BINDING]]
+// CHECK: %[[OUTPUT_BUFFER:.+]] = amdgpu.fat_raw_buffer_cast %[[OUTPUT_BINDING_ALIGNED]]
+// CHECK: %[[FOR_RESULT:.+]] = scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf32>)
+// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x4x1xf32>
+// CHECK: %[[FLAT_OUTPUT_BUFFER:.+]] = memref.collapse_shape %[[OUTPUT_BUFFER]]
+// CHECK-COUNT-4: vector.scatter %[[FLAT_OUTPUT_BUFFER]]{{.*}} : memref<65536xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
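For reference, the `map_scatter` region in this test relinearizes each `(row, col)` index of the `256x256` matmul result into the `2x16x8x4x4x4x4` output layout that the final `vector.scatter` writes through the flattened `memref<65536xf32>`. A minimal Python sketch of that index mapping, mirroring the `affine.delinearize_index` ops (the helper name and the bijectivity check are illustrative, not part of the test):

```python
# Sketch of the index mapping computed by the map_scatter region above.
# For element (row, col) of the 256x256 result, compute its position in
# the 2x16x8x4x4x4x4 output tensor.

def map_scatter_index(row: int, col: int):
    i0, i1 = divmod(row, 128)   # %13:2 = delinearize %arg0 into (2, 128)
    j0, j1 = divmod(col, 16)    # %14:2 = delinearize %arg1 into (16, 16)
    a, rem = divmod(i1, 32)     # %15:3 = delinearize %13#1 into (4, 8, 4)
    b, c = divmod(rem, 4)
    d, e = divmod(j1, 4)        # %16:2 = delinearize %14#1 into (4, 4)
    # Yield order: %13#0, %14#0, %15#1, %16#1, %15#0, %15#2, %16#0
    return (i0, j0, b, e, a, c, d)

# Sanity check: the mapping is a bijection onto the 2*16*8*4*4*4*4 = 65536
# slots, matching the flattened memref<65536xf32> in the CHECK lines.
shape = (2, 16, 8, 4, 4, 4, 4)
seen = set()
for row in range(256):
    for col in range(256):
        lin = 0
        for i, s in zip(map_scatter_index(row, col), shape):
            lin = lin * s + i
        seen.add(lin)
assert len(seen) == 65536
```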