
Commit ec53dbe

Max191 authored and keshavvinayak01 committed

[LLVMGPU] Support map_scatter in LLVMGPUVectorDistribute pipeline (iree-org#21595)

Adds the vectorization and lowering passes for `iree_linalg_ext.map_scatter`, and a pipeline test for the LLVMGPUVectorDistribute pipeline.

Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: keshavvinayak01 <[email protected]>

1 parent b8f7418 · commit ec53dbe
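
For context on the op being supported: `iree_linalg_ext.map_scatter` copies a source tensor into a destination with a different layout, using a region that maps each source index tuple to a destination index tuple plus an i1 write mask. A minimal sketch, modeled on the op's use in the new test below (the 4x4 -> 2x2x2x2 shapes and all value names are illustrative, not from this commit):

// Relayout a 4x4 tensor into 2x2x2x2 tiles: source element (i, j) moves to
// destination (i floordiv 2, j floordiv 2, i mod 2, j mod 2).
%empty = tensor.empty() : tensor<2x2x2x2xf32>
%tiled = iree_linalg_ext.map_scatter %src into %empty {
^bb0(%i: index, %j: index):
  %i_out, %i_in = affine.delinearize_index %i into (2, 2) : index, index
  %j_out, %j_in = affine.delinearize_index %j into (2, 2) : index, index
  %true = arith.constant true
  // Yield the destination indices followed by the write mask.
  iree_linalg_ext.yield %i_out, %j_out, %i_in, %j_in, %true : index, index, index, index, i1
} : tensor<4x4xf32> into tensor<2x2x2x2xf32> -> tensor<2x2x2x2xf32>

Because the index mapping lives in a region rather than a fixed attribute, the same op can express arbitrary relayouts, like the swizzled 7-D layout in the pipeline test below.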

File tree

2 files changed: +60 −0 lines changed

- compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
- compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 3 additions & 0 deletions

@@ -972,6 +972,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
   funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
 
   // Linalg -> Vector
+  funcPassManager.addPass(
+      IREE::LinalgExt::createVectorizeIREELinalgExtOpsPass());
   addGPUVectorizationPasses(funcPassManager, /*vectorizeCopies=*/true,
                             /*enableMasking=*/true);
 
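In this first hunk, the LinalgExt vectorization pass is placed just before addGPUVectorizationPasses so that `map_scatter` is vectorized together with the rest of the dispatch. A hedged sketch of the expected effect, assuming the pass replaces the op's tensor input with a vector operand while leaving the index region untouched (this intermediate IR is not shown in the diff; names are illustrative):

// Assumed post-vectorization form: the producer's result is read into a
// vector, and map_scatter now consumes the vector instead of the tensor.
%vec = vector.transfer_read %matmul_result[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<256x256xf32>, vector<256x256xf32>
%relayout = iree_linalg_ext.map_scatter %vec into %dest {
^bb0(%i: index, %j: index):
  // ... same delinearize/yield region as in the tensor form ...
} : vector<256x256xf32> into tensor<2x16x8x4x4x4x4xf32> -> tensor<2x16x8x4x4x4x4xf32>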
@@ -992,6 +994,7 @@
 
   // Vector SIMD -> Vector SIMT
   funcPassManager.addPass(createLLVMGPUVectorDistributePass());
+  funcPassManager.addPass(IREE::LinalgExt::createDecomposeMapScatterPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
 
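In this second hunk, the decomposition runs immediately after LLVMGPUVectorDistribute, once each thread holds only a small vector fragment of the result. Judging from the CHECK lines in the test below, decomposition flattens the output buffer and emits one vector.scatter per fragment; a sketch of that target form (value names are illustrative):

// Flatten the 7-D output so the region's multi-dimensional destination
// indices can be linearized into a single offset vector.
%flat = memref.collapse_shape %out [[0, 1, 2, 3, 4, 5, 6]]
    : memref<2x16x8x4x4x4x4xf32, #amdgpu.address_space<fat_raw_buffer>>
    into memref<65536xf32, #amdgpu.address_space<fat_raw_buffer>>
// %offsets holds the linearized destination indices computed from the
// map_scatter body; %mask holds the yielded i1 per element.
vector.scatter %flat[%c0] [%offsets], %mask, %values
    : memref<65536xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xindex>, vector<4xi1>, vector<4xf32>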
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir

Lines changed: 57 additions & 0 deletions

@@ -1197,3 +1197,60 @@ hal.executable private @matvec_dispatch_0 {
 // CHECK: scf.forall ({{.*}}) = (0, 0) to (32000, 2) step (16, 1)
 // CHECK-COUNT-16: gpu.subgroup_reduce add {{.*}} cluster(size = 64) : (f32) -> f32
 // CHECK-COUNT-16: gpu.subgroup_reduce add {{.*}} cluster(size = 2) : (f32) -> f32
+
+// -----
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable public @matmul_map_scatter {
+  hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @matmul_map_scatter layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_map_scatter() attributes {translation_info = #translation} {
+        %true = arith.constant true
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+        %1 = amdgpu.fat_raw_buffer_cast %0 resetOffset : memref<256x256xf16, #hal.descriptor_type<storage_buffer>> to memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+        %3 = amdgpu.fat_raw_buffer_cast %2 resetOffset : memref<256x256xf16, #hal.descriptor_type<storage_buffer>> to memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>>
+        %4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : memref<2x16x8x4x4x4x4xf32, #hal.descriptor_type<storage_buffer>>
+        %5 = amdgpu.fat_raw_buffer_cast %4 resetOffset : memref<2x16x8x4x4x4x4xf32, #hal.descriptor_type<storage_buffer>> to memref<2x16x8x4x4x4x4xf32, #amdgpu.address_space<fat_raw_buffer>>
+        %6 = iree_codegen.load_from_buffer %1 : memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<256x256xf16>
+        %7 = iree_codegen.load_from_buffer %3 : memref<256x256xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<256x256xf16>
+        %8 = tensor.empty() : tensor<256x256xf32>
+        %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<256x256xf32>) -> tensor<256x256xf32>
+        %10 = linalg.matmul {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 128], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>} ins(%6, %7 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%9 : tensor<256x256xf32>) -> tensor<256x256xf32>
+        %11 = tensor.empty() : tensor<2x16x8x4x4x4x4xf32>
+        %12 = iree_linalg_ext.map_scatter %10 into %11 {
+        ^bb0(%arg0: index, %arg1: index):
+          %13:2 = affine.delinearize_index %arg0 into (2, 128) : index, index
+          %14:2 = affine.delinearize_index %arg1 into (16, 16) : index, index
+          %15:3 = affine.delinearize_index %13#1 into (4, 8, 4) : index, index, index
+          %16:2 = affine.delinearize_index %14#1 into (4, 4) : index, index
+          iree_linalg_ext.yield %13#0, %14#0, %15#1, %16#1, %15#0, %15#2, %16#0, %true : index, index, index, index, index, index, index, i1
+        } : tensor<256x256xf32> into tensor<2x16x8x4x4x4x4xf32> -> tensor<2x16x8x4x4x4x4xf32>
+        iree_codegen.store_to_buffer %12, %5 : tensor<2x16x8x4x4x4x4xf32> into memref<2x16x8x4x4x4x4xf32, #amdgpu.address_space<fat_raw_buffer>>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: func.func @matmul_map_scatter()
+// CHECK: %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan{{.*}} binding(2)
+// CHECK: %[[OUTPUT_BINDING_ALIGNED:.+]] = memref.assume_alignment %[[OUTPUT_BINDING]]
+// CHECK: %[[OUTPUT_BUFFER:.+]] = amdgpu.fat_raw_buffer_cast %[[OUTPUT_BINDING_ALIGNED]]
+// CHECK: %[[FOR_RESULT:.+]] = scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf32>)
+// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x4x1xf32>
+// CHECK: %[[FLAT_OUTPUT_BUFFER:.+]] = memref.collapse_shape %[[OUTPUT_BUFFER]]
+// CHECK-COUNT-4: vector.scatter %[[FLAT_OUTPUT_BUFFER]]{{.*}} : memref<65536xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
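
To make the test's index mapping concrete, take source element (%arg0, %arg1) = (5, 19): delinearizing 5 over (2, 128) gives (0, 5), 5 over (4, 8, 4) gives (0, 1, 1), 19 over (16, 16) gives (1, 3), and 3 over (4, 4) gives (0, 3). The yield then routes the element to destination index (0, 1, 1, 3, 0, 1, 0) of the 2x16x8x4x4x4x4 output, and since the yielded mask is the constant true, every element is written.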
