@tyb0807 tyb0807 commented Jan 26, 2026

Add BroadcastOp to the Wave dialect for broadcasting a tensor to a larger
shape by replicating values along specified dimensions. The operation takes
a source tensor and a broadcast_dims attribute specifying which dimensions
are being added.

Some design decisions:

  • broadcast_dims attribute explicitly specifies which dimensions are added
    (source_shape + broadcast_dims = result_shape).

  • BroadcastElementsPerThreadOpTrait: EPT propagation depends on the broadcast
    dimension. When broadcasting along thread X, EPT is not propagated
    (NoChange), since the source has no thread X dimension and the result EPT
    should come from downstream users (similar to how PyWave copies the index
    from context). For other dims, identity propagation is used.

  • IdentityIndexExprsOpTrait: index expressions for shared dims propagate
    bidirectionally, broadcast dims are filled by backward propagation.

  • Custom type inference: backward propagation can infer source shape from
    result shape minus broadcast_dims.

Part of [water] Implement wave.broadcast #721
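As a rough illustration of the backward type inference mentioned in the last bullet, here is a minimal C++ sketch: dropping the broadcast dimensions from the result shape to recover the source shape. The function name and the plain `std::vector<std::string>` stand-in for symbolic dimension lists are hypothetical, not the actual Wave dialect API.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Illustrative backward inference (hypothetical helper, not Wave dialect API):
// drop the broadcast dimensions from the result shape to recover the source
// shape. Symbolic dims are plain strings here; the real implementation would
// operate on dialect types and attributes.
std::vector<std::string>
inferSourceDims(const std::vector<std::string> &resultDims,
                const std::vector<std::string> &broadcastDims) {
  std::vector<std::string> sourceDims;
  for (const std::string &dim : resultDims) {
    bool isBroadcast = std::find(broadcastDims.begin(), broadcastDims.end(),
                                 dim) != broadcastDims.end();
    if (!isBroadcast)
      sourceDims.push_back(dim);
  }
  return sourceDims;
}

int main() {
  // result_shape = [M, N], broadcast_dims = [N]  =>  source_shape = [M]
  std::vector<std::string> result = {"M", "N"};
  std::vector<std::string> broadcast = {"N"};
  for (const std::string &dim : inferSourceDims(result, broadcast))
    std::cout << dim << " ";
  std::cout << "\n"; // prints: M
}
```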


@ftynse ftynse left a comment


Why do we need the explicit broadcast dims? I thought it would be useful for type inference, but they are not sufficient for that.


tyb0807 commented Jan 26, 2026

> Why do we need the explicit broadcast dims? I thought it would be useful for type inference, but they are not sufficient for that.

Well yes, I gave this some thought before, and I think it is just for readability (explicit in the IR) and convenience in type inference (avoiding having to compute the diff broadcast_dims = result_shape - source_shape).


ftynse commented Jan 28, 2026

> Well yes, I gave this some thought before, and I think it is just for readability (explicit in the IR) and convenience in type inference (avoiding having to compute the diff broadcast_dims = result_shape - source_shape).

Readability is something that should be handled by the parser/printer and should not take an excessive amount of effort. This is a compiler IR, not a human-writable language. The cost of this duplication is that every single modification of this operation will have to ensure the explicit attribute remains consistent with the result type, and that cost largely outweighs the benefit. Similarly, computing broadcast_dims = result_shape - source_shape is cheap and straightforward, and may not be practically slower than indirecting through context-owned pointers, which is what attributes are under the hood. With the attribute, we store an extra entry both on the operation itself and in the context, and the latter is never cleaned up. There are significant costs to this duplication and virtually no benefits. Please remove it. It is almost always a bad idea to duplicate things in the IR.
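For concreteness, a minimal sketch of that diff computation in plain C++ (hypothetical helper name and string-based dims, not the actual Wave dialect API): a single linear pass, with nothing extra stored on the op or in the context.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Illustrative only: broadcast_dims = result_shape - source_shape, computed as
// one linear pass over the result dims. Hypothetical helper, not Wave dialect API.
std::vector<std::string>
computeBroadcastDims(const std::vector<std::string> &resultDims,
                     const std::vector<std::string> &sourceDims) {
  std::vector<std::string> broadcastDims;
  for (const std::string &dim : resultDims)
    if (std::find(sourceDims.begin(), sourceDims.end(), dim) ==
        sourceDims.end())
      broadcastDims.push_back(dim);
  return broadcastDims;
}

int main() {
  // result_shape = [M, N], source_shape = [M]  =>  broadcast_dims = [N]
  assert(computeBroadcastDims({"M", "N"}, {"M"}) ==
         std::vector<std::string>({"N"}));
}
```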

@tgymnich tgymnich force-pushed the users/ftynse/reductions branch from b55cc87 to 434e74e Compare January 28, 2026 12:09

tyb0807 commented Jan 28, 2026

I agree with your reasoning about redundancy. However, I notice that the reduction op has an explicit axis attribute (along @M), which is equally computable as source_shape - result_shape. Is there a specific reason that's acceptable but broadcast_dims isn't? Maybe I'm missing something?


ftynse commented Jan 28, 2026

> Is there a specific reason that's acceptable but broadcast_dims isn't? Maybe I'm missing something?

Not particularly, other than that nobody commented on it in code review. There may be a point where (partial) type inference is sufficient and is enabled by it, but at that point we should make the attribute mutually exclusive with fully specified types.

@tyb0807 tyb0807 force-pushed the broadcast branch 8 times, most recently from 76a3e0a to 013639b Compare January 29, 2026 13:25