[mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. #155517
Conversation
// communication. So each lane must own the required number of elements to
// perform the bitcast locally without cross-lane communication.
int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
if (outInnerBitsPerLane < inElemTyBitWidth) {
Check the condition:
srcInnerBitsPerLane = inElemTypeBitWidth x sourceLayout.getLaneData
if (outInnerBitsPerLane != srcInnerBitsPerLane)
I thought about this again. sourceLayout.getLaneData is not available to us because we are trying to decide it here. I think we can only detect the narrowing case.
The widening case will always be valid because, at this point, the result already has a valid layout; otherwise it means the result was not assigned a correct layout, and that should be the concern of layout-conflict handling.
In any case, I added a check to verify that the result layout is valid and can be distributed to lanes.
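The lane-locality check being discussed can be sketched as a small standalone helper. This is illustrative only, not the actual pass code; the function name `isLaneLocalBitCastPossible` is hypothetical, while `outData`, `outElemTyBitWidth`, and `inElemTyBitWidth` mirror the names in the quoted snippet:

```cpp
#include <cstdint>
#include <vector>

// Sketch: a bitcast can stay lane-local only if the innermost per-lane
// output bits cover at least one source element. Otherwise a lane would
// need bits owned by another lane (cross-lane communication) to
// reconstruct a single source element.
bool isLaneLocalBitCastPossible(const std::vector<int64_t> &outData,
                                int outElemTyBitWidth,
                                int inElemTyBitWidth) {
  int64_t rank = static_cast<int64_t>(outData.size());
  int64_t outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  // Narrowing case (e.g. f32 -> f16): valid only if each lane owns
  // enough narrow elements to rebuild whole wide elements.
  return outInnerBitsPerLane >= inElemTyBitWidth;
}
```

For example, an f32 -> f16 narrowing bitcast with lane data [1, 2] gives each lane 32 output bits, which covers one f32 source element, whereas lane data [1, 1] gives only 16 bits and would require cross-lane communication.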
I see. I would move the check to after sourceLaneData is assigned. See the comments below as well.
shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
return;
}
// For 2D -> 2D shape cast, propagate the result layout to the source.
Consider these restrictions for now:
- same-rank shape casts are not allowed,
- always expand the rank, never squeeze it,
- the new dims must be 1, and the original dims must not change.
Fixed. I also added this condition for now:
- The result layout cannot be a slice layout, and it must have the same rank as the result.
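The rank-expansion restrictions listed above can be sketched as a standalone predicate. This is an illustrative version under the stated rules, not the actual pass code; the name `isValidRankExpandingShapeCast` is hypothetical:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the restriction: a shape cast must strictly expand rank,
// every newly introduced dim must be a unit dim (1), and the original
// dims must appear unchanged and in order.
bool isValidRankExpandingShapeCast(const std::vector<int64_t> &srcShape,
                                   const std::vector<int64_t> &resShape) {
  if (resShape.size() <= srcShape.size())
    return false; // same-rank or rank-squeezing casts are not allowed
  size_t srcIdx = 0;
  for (int64_t d : resShape) {
    if (srcIdx < srcShape.size() && d == srcShape[srcIdx])
      ++srcIdx;     // an original dim, consumed in order
    else if (d != 1)
      return false; // any new dim must be a unit dim
  }
  return srcIdx == srcShape.size(); // all original dims were preserved
}
```

Under these rules, 16 -> 16x1 and 16 -> 1x16 are accepted, while 16x1 -> 16 (squeezing) and 16 -> 8x2 (reshaping an original dim) are rejected.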
Usually smaller PRs make reviews go faster, but I'll bite 😉
Overall the logic looks good; only minor comments.
for (int64_t idx : permutation) {
  newLayout.layout.push_back(laneLayout.layout[idx]);
  newData.layout.push_back(laneData.layout[idx]);
  laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
How about adding one more utility to the layout attribute, like getTransposedLayout(), so that it can be reused by sg_layout or lane_layout?
Potentially, isTransposeOf could then be simplified to doing a transpose of the input and comparing whether the two are the same.
Agree. I will add this in a separate PR and clean up.
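The suggested utility could look roughly like the following sketch. The free-function form and names here are hypothetical (the real version would presumably be a method on the layout attribute), but it shows how `isTransposeOf` reduces to a transpose plus an equality comparison:

```cpp
#include <cstdint>
#include <vector>

// Sketch of a reusable getTransposedLayout(): permute one layout array
// (sg_layout, lane_layout, or lane_data alike) by a permutation.
std::vector<int32_t>
getTransposedLayout(const std::vector<int32_t> &layout,
                    const std::vector<int64_t> &permutation) {
  std::vector<int32_t> transposed;
  transposed.reserve(layout.size());
  for (int64_t idx : permutation)
    transposed.push_back(layout[idx]);
  return transposed;
}

// isTransposeOf then becomes: transpose one side and compare.
bool isTransposeOf(const std::vector<int32_t> &a,
                   const std::vector<int32_t> &b,
                   const std::vector<int64_t> &permutation) {
  return getTransposedLayout(a, permutation) == b;
}
```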
func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
  %c0 = arith.constant 0 : index
  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16>
  %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16>
This seems to contradict the documentation:
2) Shape cast must always expand the rank (e.g. 1D -> 2D).
Not sure why the code is passing. Maybe I missed something?
Sorry, I forgot to remove this test (CI was failing because of it). I removed these tests now.
Shape cast must always expand the rank (e.g. 1D -> 2D).
If you refer to vector.shape_cast, a cast must only preserve the number of elements; the shape's rank can be freely changed up or down.
The two cases looked valid; it'd be good to understand why they failed.
If they can't be distributed, I'd leave them in as negative examples.
@adam-smnk The restriction is there because we do not expect (for now) any narrowing shape casts. Shape cast is currently used to make the vector 2D after a 2D -> 1D reduction.
Adding back the tests as negative examples for now.
My bad. The pass is designed to fail if we cannot assign a proper layout to ops, so I cannot add the negative examples in the same file, AFAIK.
Hmm, then it's something to rethink if it impacts testing.
A separate test file would be fine as this one's already pretty large. Not sure if verify-diagnostics can also test pass failures. TBD.
Not sure if verify-diagnostics can also test pass failures.
I think it can. The challenge is doing it in the same file; I did not find any examples, but I will give it a try.
return;
}
// Decide lane data based on whether the bitcast is narrowing or widening.
int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
For a narrowing bitcast, shouldn't it be innerMostLaneData = outData[rank - 1] * bitCastRatio instead of / bitCastRatio?
Put a TODO here? Check the layout-conflict case here: if (innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane).
For narrowing bitcast, innerMostLaneData = outData[rank - 1] * bitCastRatio, instead of / bitCastRatio?
This is because in the narrowing case the source had a higher bitwidth (e.g. f32 -> f16).
Put a TODO here?: check the layout conflict case here if (innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane).
This is not required. At this point of layout propagation, the result layout is already a valid layout. We chose innerMostLaneData such that innerMostLaneData * inElemTyBitWidth == outInnerBitsPerLane.
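The invariant above can be checked with a tiny standalone sketch. The helper name `chooseInnerMostLaneData` and its signature are hypothetical, not the actual pass code; it only demonstrates why the narrowing case divides by the ratio:

```cpp
#include <cstdint>

// Sketch: pick the source's innermost lane data so that
//   innerMostLaneData * inElemTyBitWidth == outInnerBitsPerLane.
// Narrowing (e.g. f32 -> f16): source elements are wider, so the
// source lane data shrinks by the ratio. Widening: it grows.
int64_t chooseInnerMostLaneData(int64_t outInnerLaneData,
                                int inElemTyBitWidth,
                                int outElemTyBitWidth) {
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int64_t bitCastRatio = isNarrowing
                             ? inElemTyBitWidth / outElemTyBitWidth
                             : outElemTyBitWidth / inElemTyBitWidth;
  return isNarrowing ? outInnerLaneData / bitCastRatio
                     : outInnerLaneData * bitCastRatio;
}
```

For f32 -> f16 with output lane data 2, this picks a source lane data of 1, and 1 * 32 == 2 * 16, so the invariant holds by construction.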
@adam-smnk Can you take another look and/or approve? :-)
LGTM
This PR adds the features needed for supporting the GEMM with transpose B case.

Summary of changes:
1. Add distribution logic for the vector.bitcast, vector.transpose and memref.extract_aligned_pointer_as_index cases.
2. Add layout propagation support for vector.shape_cast, vector.broadcast and vector.bitcast.
3. Incorporate the slice attribute and the DistributeLayoutAttr interface with the core logic in layout propagation.