Skip to content

Commit 65a52a9

Browse files
authored
At the beginning of emulate narrow type, flatten incoming memrefs (iree-org#21910)
As a first step to simplify emulation mechanism, make sure we are emulating only linearized memrefs.
1 parent be510b6 commit 65a52a9

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,11 @@ LogicalResult emulateNarrowType(
616616
affine::AffineDialect, IREE::HAL::HALDialect>(opLegalCallback);
617617

618618
RewritePatternSet patterns(ctx);
619+
620+
// Try to flatten memrefs as a prerequiste for narrow type emulation,
621+
// so we can have simplified checks in the emulation patterns.
622+
memref::populateFlattenMemrefsPatterns(patterns);
623+
619624
patterns.insert<IREEConvertVectorStore>(ctx, /*disableAtomicRMW=*/false,
620625
/*benefit=*/100);
621626
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);

compiler/src/iree/compiler/Codegen/Common/test/emulate_narrow_type.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,24 @@ func.func @broadcast_extui() -> vector<1x1x64xi32> {
4646
// CHECK-LABEL: func @broadcast_extui()
4747
// CHECK-NOT: vector.bitcast
4848
// CHECK: vector.interleave
49+
50+
// -----
51+
52+
#pipeline_layout = #hal.pipeline.layout<bindings = [
53+
#hal.pipeline.binding<storage_buffer>
54+
]>
55+
func.func @memref_load_2d_i4() -> i4 {
56+
%c0 = arith.constant 0 : index
57+
%c1 = arith.constant 1 : index
58+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<16x32xi4>
59+
%v = memref.load %0[%c1, %c0] : memref<16x32xi4>
60+
return %v : i4
61+
}
62+
63+
// CHECK-LABEL: func.func @memref_load_2d_i4()
64+
// CHECK: %[[C16:.*]] = arith.constant 16 : index
65+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
66+
// CHECK: %[[SUBSPAN:.*]] = hal.interface.binding.subspan {{.*}} : memref<256xi8>
67+
// CHECK: %[[LOAD:.*]] = memref.load %[[SUBSPAN]][%[[C16]]] : memref<256xi8>
68+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[LOAD]] : i8 to i4
69+
// CHECK: return %[[TRUNC]] : i4

0 commit comments

Comments
 (0)