-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] Rewrites for I2 to I8 signed and unsigned extension #121298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
908a6eb
e3c5c56
99b6785
303416a
cfe31bb
0313d89
3b1005d
b975051
1196049
106f8b7
7e25b9a
70ae38a
86e11c4
5c5396b
6114fcf
8abb46d
bf111da
44522e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -1090,12 +1090,16 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, | |||
| unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth(); | ||||
| unsigned dstElemBitwidth = dstType.getElementTypeBitWidth(); | ||||
|
|
||||
| // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now. | ||||
| if (srcElemBitwidth != 4 || dstElemBitwidth < 8 || | ||||
| (dstElemBitwidth % srcElemBitwidth) != 0) | ||||
| return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); | ||||
| if (dstElemBitwidth < 8) | ||||
| return rewriter.notifyMatchFailure( | ||||
| op, "the bitwidth of dstType must be greater than or equal to 8"); | ||||
| if (dstElemBitwidth % srcElemBitwidth != 0) | ||||
| return rewriter.notifyMatchFailure(op, "unaligned cases are not supported"); | ||||
| if (srcElemBitwidth != 2 && srcElemBitwidth != 4) | ||||
| return rewriter.notifyMatchFailure( | ||||
| op, "only src bitwidth of 2 or 4 is supported at this moment"); | ||||
|
|
||||
| const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; | ||||
| const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; | ||||
|
||||
| if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0) | ||||
| return rewriter.notifyMatchFailure( | ||||
| op, "Not an even number of i4 elements in trailing dim"); | ||||
|
|
@@ -1179,70 +1183,166 @@ Value BitCastRewriter::genericRewriteStep( | |||
| return runningResult; | ||||
| } | ||||
|
|
||||
| /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and | ||||
| /// bitwise ops that take advantage of high-level information to avoid leaving | ||||
| /// LLVM to scramble with peephole optimizations. | ||||
| static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, | ||||
| Value srcValue) { | ||||
| VectorType srcVecType = cast<VectorType>(srcValue.getType()); | ||||
| assert(srcVecType.getElementType().isSignlessInteger(4) && | ||||
| "Expected i4 type"); | ||||
| /// Bitcasts the aligned `subByteVec` vector to a vector of i8. | ||||
| /// Where aligned means it satisfies the alignedConversionPreconditions. | ||||
| /// | ||||
| /// Example: | ||||
| /// vector<16x16xi2> -> vector<16x2xi8> | ||||
| /// vector<16x16xi4> -> vector<16x4xi8> | ||||
ziereis marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, | ||||
| Value subByteVec) { | ||||
| auto srcVecType = cast<VectorType>(subByteVec.getType()); | ||||
| int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth(); | ||||
| assert(8 % srcBitwidth == 0 && | ||||
| "Unsupported sub-byte type (not a divisor of i8)"); | ||||
| int64_t bitwidthFactor = 8 / srcBitwidth; | ||||
|
||||
| const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, there are some other places that refer to the same name but im not sure if they refer to the same thing
ziereis marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit confused about the naming and this description.
IIUC, this method will:
- for every byte
binsrc(which is a vector of bytes), - extracts
numBitsstarting atbitIdx(let's call itinputVal), and - returns a byte matching the value encoded in
inputVal.
So this method is more like extractNBitsAndReturnAsByte?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- for every byte
binsrc(which is a vector of bytes),- extracts
numBitsstarting atbitIdx(let's call itinputVal), and- returns a byte matching the value encoded in
inputVal.
it will extract numBits for every byte of src at bitIdx and will return a vector of bytes, the resultType will always be the same as the srcType.
So for example lets say numBits is 4, it will treat the inputVal as a i4 and (sign)ext it to a i8 value.
im not sure about the name either, maybe extractNBitsAnd(Sign)ExtToI8 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extractNbitsPerByteAndExtendToI8?
Am I correct that this method assumes that the src and dst element type is i8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
Uh oh!
There was an error while loading. Please reload this page.