Skip to content

Commit 3b62d23

Browse files
authored
[CPU][DT] Add codegen support for broadcast/dequant -> matmul dispatch. (iree-org#21911)
The encoding materialization is moved before RematerializeParallelOps pass, because it can make matmul ops (in generic op form) have more than two input operands and the materialization pattern does not recognize the case. Instead of having the checks in the pass, the revision runs the materialization pass earlier, which is an easier and reasonable fix. It also adds the LLVMCPUTileToVectorSize pass to the mmt4d pipeline, so the producer can be tiled to the target vector size, which avoids large vectors and stack allocations. Fixes iree-org#21866 --------- Signed-off-by: hanhanW <[email protected]>
1 parent 3edd217 commit 3b62d23

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ void addMmt4dTilingExpertPassPipeline(
397397
funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperandsPass(
398398
IREE::CPU::TilingLevel::VectorReductionTiles));
399399
funcPassManager.addPass(iree_compiler::createForallToForPass());
400+
funcPassManager.addPass(createLLVMCPUTileToVectorSizePass());
400401

401402
{
402403
GenericVectorizationPassOptions options;
@@ -663,12 +664,12 @@ void buildLLVMCPUCodegenConfigurationPassPipelineImpl(
663664
}
664665
modulePassManager.addPass(createMaterializeUserConfigsPass());
665666
FunctionLikeNest(modulePassManager)
667+
.addPass(createMaterializeDeviceEncodingPass)
668+
.addPass(createCPUPropagateDataLayoutPass)
666669
.addPass(createRematerializeParallelOpsPass)
667670
// TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added
668671
// way to late and should insted be be done during lowering to LLVM.
669672
.addPass(createExpandF16OpToF32Pass)
670-
.addPass(createMaterializeDeviceEncodingPass)
671-
.addPass(createCPUPropagateDataLayoutPass)
672673
.addPass(createConvertAccGEMMToGEMMPass)
673674
// TODO: Remove the following pass the plumb support for
674675
// #hal.descriptor_type memory space through the stack.

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,13 +469,21 @@ module {
469469
return
470470
}
471471
}
472+
// TODO(#21696): Respect the tile sizes for packed domain when we set the
473+
// lowering configs. The `alloca` op is generated because it performs additional
474+
// tiling on the generic op that targets vector sizes config. The `alloca` op is
475+
// not needed if we don't tile it at all, which means that it can be solved by
476+
// not setting the tile size for the packed dimensions.
477+
472478
// CHECK-LABEL: func.func @mmt4d_bias_relu
473-
// CHECK-NOT: memref.alloc
479+
// CHECK: memref.alloca() {alignment = 64 : i64} : memref<1x1x2x16xf32
474480
// CHECK: scf.forall
475481
// CHECK: scf.for
476482
// CHECK: vector.fma
477-
// CHECK: vector.insert
483+
// CHECK: vector.fma
484+
// CHECK: }
478485
// CHECK: arith.addf
486+
// CHECK: arith.maximumf
479487

480488
// -----
481489

@@ -683,3 +691,65 @@ func.func @matmul_accumulate_from_readonly(%M: index, %N: index, %K: index) attr
683691
// CHECK-LABEL: func.func @matmul_accumulate_from_readonly(
684692
// CHECK-NOT: memref.alloc
685693
// CHECK-NOT: linalg.generic
694+
695+
// -----
696+
697+
// Verifies that the backend can handle broadcast/dequant op followed by a
698+
// matmul with encodings. We only check if the ukernel op is generated or not.
699+
// The test ensures that there are no big vectors and stack allocations when it
700+
// succeeds.
701+
702+
#encoding = #iree_encoding.layout<[#iree_cpu.cpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 2], outerDimsPerm = [0, 1]}}>]>
703+
#encoding1 = #iree_encoding.layout<[#iree_cpu.cpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0], innerTileSizes = [2], outerDimsPerm = [0]}}>]>
704+
#encoding2 = #iree_encoding.layout<[#iree_cpu.cpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [1, 0], innerTileSizes = [16, 2], outerDimsPerm = [1, 0]}}>]>
705+
#encoding3 = #iree_encoding.layout<[#iree_cpu.cpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1]}}>]>
706+
#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+f16c,+fsgsbase,+crc32,+invpcid,+rdpru,+sahf,+lzcnt,+movbe,+mwaitx,+x87,+pku,+prfchw,+rdpid,+rdrnd,+rdseed,+sha,+shstk,+vaes,+wbnoinvd,+xsave,+xsavec,+xsaveopt,+xsaves,+fxsr", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>, max_stack_allocation_size = 32768 : i64, native_vector_size = 64 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
707+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
708+
#map1 = affine_map<(d0, d1) -> (d0, d1)>
709+
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
710+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
711+
#map4 = affine_map<(d0, d1) -> (d1)>
712+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
713+
#encoding4 = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map2, #map3], iteration_sizes = [123, 789, 456]>
714+
#encoding5 = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map2, #map3], iteration_sizes = [123, 789, 456]>
715+
#encoding6 = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map2, #map3], iteration_sizes = [123, 789, 456]>
716+
#encoding7 = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [[#map, #map1], #map2, #map3], iteration_sizes = [123, 789, 456]>
717+
#encoding8 = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [[#map, #map4], #map2, #map3], iteration_sizes = [123, 789, 456]>
718+
module {
719+
func.func @dequant_lhs_matmul() attributes {hal.executable.target = #executable_target_embedded_elf_x86_64} {
720+
%c0 = arith.constant 0 : index
721+
%c29184 = arith.constant 29184 : index
722+
%c29440 = arith.constant 29440 : index
723+
%c394240 = arith.constant 394240 : index
724+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<123x456xi4, #encoding>>
725+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c29184) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<456xi4, #encoding1>>
726+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c29440) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<456x789xi8, #encoding2>>
727+
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c394240) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<123x789xi32, #encoding3>>
728+
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<123x789xi32>>
729+
%5 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [123, 456], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<123x456xi4, #encoding>> -> tensor<123x456xi4, #encoding7>
730+
%6 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0], sizes = [456], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<456xi4, #encoding1>> -> tensor<456xi4, #encoding8>
731+
%7 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [456, 789], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<456x789xi8, #encoding2>> -> tensor<456x789xi8, #encoding4>
732+
%8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [123, 789], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<123x789xi32, #encoding3>> -> tensor<123x789xi32, #encoding5>
733+
%9 = tensor.empty() : tensor<123x456xi8, #encoding6>
734+
%10 = linalg.generic {indexing_maps = [#map1, #map4, #map1], iterator_types = ["parallel", "parallel"]} ins(%5, %6 : tensor<123x456xi4, #encoding7>, tensor<456xi4, #encoding8>) outs(%9 : tensor<123x456xi8, #encoding6>) {
735+
^bb0(%in: i4, %in_0: i4, %out: i8):
736+
%13 = arith.extui %in : i4 to i8
737+
%14 = arith.extsi %in_0 : i4 to i8
738+
%15 = arith.subi %13, %14 : i8
739+
linalg.yield %15 : i8
740+
} -> tensor<123x456xi8, #encoding6>
741+
%11 = linalg.generic {indexing_maps = [#map, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%10, %7 : tensor<123x456xi8, #encoding6>, tensor<456x789xi8, #encoding4>) outs(%8 : tensor<123x789xi32, #encoding5>) {
742+
^bb0(%in: i8, %in_0: i8, %out: i32):
743+
%13 = arith.extsi %in : i8 to i32
744+
%14 = arith.extsi %in_0 : i8 to i32
745+
%15 = arith.muli %13, %14 : i32
746+
%16 = arith.addi %out, %15 : i32
747+
linalg.yield %16 : i32
748+
} -> tensor<123x789xi32, #encoding5>
749+
%12 = iree_encoding.unset_encoding %11 : tensor<123x789xi32, #encoding5> -> tensor<123x789xi32>
750+
iree_tensor_ext.dispatch.tensor.store %12, %4, offsets = [0, 0], sizes = [123, 789], strides = [1, 1] : tensor<123x789xi32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<123x789xi32>>
751+
return
752+
}
753+
}
754+
// CHECK-LABEL: func.func @dequant_lhs_matmul(
755+
// CHECK: iree_codegen.ukernel.generic

0 commit comments

Comments
 (0)