Skip to content

Commit bf7c490

Browse files
[mlir][Vector] Add a rewrite pattern for better low-precision bitcast(trunci) expansion (llvm#66387)
This revision adds a rewrite for sequences of vector `bitcast(trunci)` to use a more efficient sequence of vector operations comprising `shuffle` and `bitwise` ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect. The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the pre-trunci vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori` followed by an optional final `trunci`/`extui`. The rewrite currently only applies to 1-D non-scalable vectors and bails out if the bitwidth of the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting element type does not occupy a whole number of bytes, further complexities arise that are not improved by this pattern, and the heavy lifting still needs to be done by LLVM.
1 parent b8f6443 commit bf7c490

File tree

7 files changed

+658
-5
lines changed

7 files changed

+658
-5
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
292292
}];
293293
}
294294

295+
// Transform dialect op with mnemonic
// `apply_patterns.vector.rewrite_narrow_types`. It declares the methods of
// PatternDescriptorOpInterface and has no operands or results in its assembly
// format (only an attribute dictionary).
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
296+
"apply_patterns.vector.rewrite_narrow_types",
297+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
298+
let description = [{
299+
Indicates that vector narrow rewrite operations should be applied.
300+
301+
This is usually a late step that is run after bufferization as part of the
302+
process of lowering to e.g. LLVM or NVVM.
303+
}];
304+
305+
let assemblyFormat = "attr-dict";
306+
}
307+
295308
def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
296309
"apply_patterns.vector.split_transfer_full_partial",
297310
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class RewritePatternSet;
2424

2525
namespace arith {
2626
class NarrowTypeEmulationConverter;
27+
class TruncIOp;
2728
} // namespace arith
2829

2930
namespace vector {
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
143144

144145
/// Patterns that remove redundant vector broadcasts.
145146
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
146-
PatternBenefit benefit = 1);
147+
PatternBenefit benefit = 1);
147148

148149
/// Populate `patterns` with the following patterns.
149150
///
@@ -301,6 +302,18 @@ void populateVectorNarrowTypeEmulationPatterns(
301302
arith::NarrowTypeEmulationConverter &typeConverter,
302303
RewritePatternSet &patterns);
303304

305+
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
306+
/// vector operations comprising `shuffle` and `bitwise` ops.
///
/// Per the accompanying commit description, the rewrite currently only
/// applies to 1-D non-scalable vectors and bails out when the bitwidth of the
/// final vector element type is not a multiple of 8.
///
/// Returns a `FailureOr<Value>`: presumably the replacement value on success
/// and failure when the rewrite does not apply (confirm against the
/// implementation, which is not part of this header).
/// NOTE(review): `maybeBroadcastOp` looks optional judging by its name;
/// whether a null op is accepted is not documented here -- TODO confirm.
307+
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
308+
vector::BitCastOp bitCastOp,
309+
arith::TruncIOp truncOp,
310+
vector::BroadcastOp maybeBroadcastOp);
311+
312+
/// Appends patterns for rewriting vector operations over narrow types with
313+
/// ops over wider types.
///
/// `patterns` is the set the patterns are appended to; `benefit` is the
/// pattern benefit to use and defaults to 1.
314+
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
315+
PatternBenefit benefit = 1);
316+
304317
} // namespace vector
305318
} // namespace mlir
306319

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,16 @@ class VectorType::Builder {
357357
return *this;
358358
}
359359

360+
/// Set the dimension at index `pos` of the shape to `val` and return this
/// builder so that calls can be chained.
///
/// If `storage` is still empty, the current `shape` is first copied into it;
/// the write then happens in `storage` and `shape` is re-pointed at that
/// buffer. Asserts that `pos` is a valid index into `storage`.
361+
Builder &setDim(unsigned pos, int64_t val) {
362+
if (storage.empty())
363+
storage.append(shape.begin(), shape.end());
364+
assert(pos < storage.size() && "overflow");
365+
storage[pos] = val;
366+
// Make `shape` refer to the (possibly newly filled) `storage` contents.
shape = {storage.data(), storage.size()};
367+
return *this;
368+
}
369+
360370
operator VectorType() {
361371
return VectorType::get(shape, elementType, scalableDims);
362372
}

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
159159
}
160160
}
161161

162+
/// Adds the vector narrow-type rewrite patterns to `patterns` by forwarding to
/// `populateVectorNarrowTypeRewritePatterns`; no explicit benefit is passed,
/// so that function's default is used.
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
163+
RewritePatternSet &patterns) {
164+
populateVectorNarrowTypeRewritePatterns(patterns);
165+
}
166+
162167
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
163168
RewritePatternSet &patterns) {
164169
vector::VectorTransformsOptions vectorTransformOptions;

0 commit comments

Comments
 (0)