Skip to content

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Jan 14, 2026

and test for matrix add where it is inferred from writes

and test for matrix add where it is inferred from writes

Signed-off-by: Alex Zinenko <[email protected]>
Comment on lines +1757 to +1760
// TODO: pywave just ignores this not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from write?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.

Comment on lines +1782 to +1786
emitError() << "expected a single workgroup constraint for dimension "
<< tensorType.getShape()[i]
<< " used in the write operation without explicit "
"`elements_per_thread`";
return failure();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, but in absence of a workgroup constraint?

It feels like we need to set it to start=0, and likely size=1 and stride=1 but not sure

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the code does make some assumptions like that where it falls back to start = 0, size and stride of 1, but I think we shouldn't allow that and instead be more explicit.

Comment on lines +1861 to +1867
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment in the original source (

# We have an assumption that the thread dimensions in each wave is of shape (64,1,1).
# In cases other than dimension 0, we also calculate the modulus of thread_id with the
# number of threads in that dimension to prevent double counting of thread ID in thread
# independent index.
) says something about preventing double counting of thread id, but I can't infer where and why it would be counted twice. The support for it was added in a commit for atomics, 4eeee9a, which is doesn't provide an explanation either

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comes up in the SIMT context (no MMA, you can also see this in the example for the atomic case). If you look at the original code, what was happening was that because we dont have an MMA, the default pattern for SIMT is a thread linear pattern and so for the atomicAdd we were getting a dependence on x and y, even though that shouldn't be the case for the example. So this was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants