WIP: priority-based index propagation #734
base: main
Conversation
and test for matrix add where it is inferred from writes

Signed-off-by: Alex Zinenko <[email protected]>
```cpp
// TODO: pywave just ignores this not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
```
Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from writes?
If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.
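The stricter behavior being argued for here can be sketched as follows; note that `Constraint` and the helper name are hypothetical stand-ins, not the actual wave/pywave API:

```python
from collections import namedtuple

# Illustrative constraint record; the real project types differ.
Constraint = namedtuple("Constraint", ["dim", "start", "size", "stride"])

def constraints_for_write(write_symbols, constraints):
    """Map each symbol used by a write to its matching constraints.

    Unlike the pywave behavior described in the TODO above, a symbol with
    no matching constraint is reported rather than silently skipped, since
    it may still correspond to a real tensor dimension or a value set
    dynamically in the kernel.
    """
    result = {}
    for sym in write_symbols:
        matching = [c for c in constraints if c.dim == sym]
        if not matching:
            raise ValueError(
                f"no constraint found for symbol {sym!r} used by the write")
        result[sym] = matching
    return result
```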
```cpp
emitError() << "expected a single workgroup constraint for dimension "
            << tensorType.getShape()[i]
            << " used in the write operation without explicit "
               "`elements_per_thread`";
return failure();
```
Ditto, but in the absence of a workgroup constraint?
It feels like we need to set it to start=0, and likely size=1 and stride=1, but I'm not sure.
I think the code does make some assumptions like that, where it falls back to start = 0 and a size and stride of 1, but I think we shouldn't allow that and should instead be more explicit.
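The two options under discussion, a permissive fallback versus an explicit error, can be contrasted in a minimal sketch; all names here are illustrative, not the project's actual types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class IndexSequence:
    start: int
    size: int
    stride: int

def sequence_for_dim(workgroup_seq: Optional[IndexSequence],
                     allow_fallback: bool = False) -> IndexSequence:
    """Resolve the index sequence for one tensor dimension.

    The permissive path silently falls back to start=0, size=1, stride=1
    when no workgroup constraint covers the dimension; the stricter path,
    which the review prefers, rejects that case outright.
    """
    if workgroup_seq is not None:
        return workgroup_seq
    if allow_fallback:
        # Implicit default the review argues against allowing.
        return IndexSequence(start=0, size=1, stride=1)
    raise ValueError("expected a workgroup constraint for this dimension")
```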
```cpp
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
```
The comment in the original source (wave/wave_lang/kernel/wave/constraints.py, lines 498 to 501 in 601ab68):

```python
# We have an assumption that the thread dimensions in each wave is of shape (64,1,1).
# In cases other than dimension 0, we also calculate the modulus of thread_id with the
# number of threads in that dimension to prevent double counting of thread ID in thread
# independent index.
```
This comes up in the SIMT context (no MMA; you can also see this in the example for the atomic case). In the original code, because there is no MMA, the default SIMT pattern is a thread-linear pattern, so for the atomicAdd we were getting a dependence on x and y, even though that shouldn't be the case for the example. This was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.
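A minimal Python sketch of the modulus behavior the quoted comment describes, under its assumed (64, 1, 1) wave shape; the function name and signature are illustrative:

```python
def thread_index_offset(thread_id, dim, threads_per_wave=(64, 1, 1)):
    """Per-dimension thread contribution to an index expression.

    Mirrors the `startExpr % threadsPerWave` behavior discussed above:
    for dimension 0 the modulus yields the lane ID within a wave, while
    for dimensions 1 and 2 (thread count 1) it is always 0, so a SIMT
    (no-MMA) index picks up no spurious y/z thread-ID dependence.
    """
    return thread_id[dim] % threads_per_wave[dim]
```

For example, thread (70, 3, 2) contributes lane 6 in dimension 0 and nothing in the other two dimensions.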