[water] Add wave.broadcast operation definition #778
base: users/ftynse/reductions
Conversation
ftynse left a comment:
Why do we need the explicit broadcast dims? I thought it would be useful for type inference, but they are not sufficient for that.
Well yes, I gave this some thought before, and I think this is just for readability (being explicit in the IR) and convenience in type inference (avoiding computing the diff …
Readability is something that should be handled by the parser/printer and not take an excessive amount of effort. This is a compiler IR, not a human-writable language. The cost of this duplication is that every single modification of this operation will have to ensure the explicit attribute remains consistent with the result type. This cost largely outweighs the benefit. Similarly, computing …
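To make the consistency burden concrete, here is a minimal plain-C++ sketch of the invariant in question (hypothetical names, dims modeled as strings; this is not the actual water verifier API): the explicit broadcast_dims attribute must always equal the dims that appear in the result type but not in the source type, so any rewrite that changes either type has to update the attribute in lockstep.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Shape = std::vector<std::string>; // symbolic dims as strings (assumed)

// The invariant a verifier would enforce: broadcast_dims must match the
// dims present in the result shape but absent from the source shape.
bool broadcastDimsConsistent(const Shape &sourceShape,
                             const Shape &resultShape,
                             const Shape &broadcastDims) {
  Shape derived;
  for (const std::string &dim : resultShape)
    if (std::find(sourceShape.begin(), sourceShape.end(), dim) ==
        sourceShape.end())
      derived.push_back(dim);
  return derived == broadcastDims;
}

int main() {
  assert(broadcastDimsConsistent({"M"}, {"M", "N"}, {"N"}));
  // Changing the result type without updating the attribute breaks it.
  assert(!broadcastDimsConsistent({"M"}, {"M", "N", "K"}, {"N"}));
}
```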
Force-pushed from b55cc87 to 434e74e.
I agree with your reasoning about redundancy. However, I notice that the reduction op has an explicit axis attribute (along …
Not particularly, other than nobody commented on it in code review. There may be a point where (partial) type inference is sufficient and is enabled by it, but at that point we should make the attribute mutually exclusive with fully specified types.
Force-pushed from 76a3e0a to 013639b.
Add BroadcastOp to the Wave dialect for broadcasting a tensor to a larger
shape by replicating values along specified dimensions. The operation takes
a source tensor and a broadcast_dims attribute specifying which dimensions
are being added.
Some design decisions (illustrative sketches follow this list):

- broadcast_dims attribute explicitly specifies which dimensions are added
  (source_shape + broadcast_dims = result_shape).
- BroadcastElementsPerThreadOpTrait: EPT propagation depends on the broadcast
  dimension. When broadcasting along thread X, EPT is not propagated
  (NoChange) since the source has no thread X and the result EPT should come
  from downstream users (similar to how PyWave copies the index from context).
  For other dims, identity propagation is used.
- IdentityIndexExprsOpTrait: index expressions for shared dims propagate
  bidirectionally; broadcast dims are filled by backward propagation.
- Custom type inference: backward propagation can infer the source shape from
  the result shape minus broadcast_dims.
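As a concrete illustration of the first and last bullets, here is a minimal plain-C++ sketch of the backward direction (dimension names are plain strings standing in for Wave symbolic dimensions; none of this is the actual MLIR/water API): removing broadcast_dims from the result shape recovers the source shape.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Shape = std::vector<std::string>; // symbolic dims as strings (assumed)

// Backward inference: the source shape is the result shape with the
// broadcast dims removed (result_shape - broadcast_dims = source_shape).
Shape inferSourceShape(const Shape &resultShape, const Shape &broadcastDims) {
  Shape source;
  for (const std::string &dim : resultShape)
    if (std::find(broadcastDims.begin(), broadcastDims.end(), dim) ==
        broadcastDims.end())
      source.push_back(dim);
  return source;
}

int main() {
  // Broadcasting {M, K} along N yields {M, N, K}; going backward drops N.
  assert(inferSourceShape({"M", "N", "K"}, {"N"}) == (Shape{"M", "K"}));
}
```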
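The EPT rule reduces to one decision per broadcast dimension; a minimal plain-C++ model follows (the boolean flag and the optional encoding are assumptions for illustration, not the BroadcastElementsPerThreadOpTrait interface).

```cpp
#include <cassert>
#include <optional>

// Forward EPT propagation through wave.broadcast: return the result's
// elements-per-thread given the source's, or nullopt when the op does
// not constrain it (NoChange) and downstream users must decide.
std::optional<int> propagateEptForward(std::optional<int> sourceEpt,
                                       bool broadcastsAlongThreadX) {
  if (broadcastsAlongThreadX)
    return std::nullopt; // NoChange: the source has no thread-X component.
  return sourceEpt;      // Identity: the result EPT matches the source.
}

int main() {
  assert(!propagateEptForward(4, /*broadcastsAlongThreadX=*/true));
  assert(propagateEptForward(4, /*broadcastsAlongThreadX=*/false) == 4);
}
```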
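Finally, a sketch of one direction of the index-expression rule (again plain C++ with hypothetical names): when walking from result to source, expressions for shared dims are copied unchanged, while broadcast dims stay on the result side, to be filled by backward propagation from users.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

using IndexExprs = std::map<std::string, std::string>; // dim -> index expr

// Shared dims propagate identically (shown here: result -> source);
// broadcast dims have no counterpart on the source and are skipped.
IndexExprs sourceExprsFromResult(const IndexExprs &resultExprs,
                                 const std::vector<std::string> &broadcastDims) {
  IndexExprs sourceExprs;
  for (const auto &[dim, expr] : resultExprs)
    if (std::find(broadcastDims.begin(), broadcastDims.end(), dim) ==
        broadcastDims.end())
      sourceExprs[dim] = expr;
  return sourceExprs;
}

int main() {
  IndexExprs result = {{"M", "T0 * 4"}, {"N", "T1"}};
  // N is the broadcast dim, so only M's expression flows back.
  assert(sourceExprsFromResult(result, {"N"}) == (IndexExprs{{"M", "T0 * 4"}}));
}
```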
Part of [water] Implement wave.broadcast #721

Signed-off-by: tyb0807 <[email protected]>