Conversation

@ftynse (Contributor) commented Jan 15, 2026

Introduce an ODS class to add support for reduction operations and instantiate it for sum and max_element reductions. Note the name of the latter: we may eventually want a binary max operation whose name would clash, unlike the add/sum dichotomy.

Implement type inference and elements-per-thread propagation for reductions using traits. This required modifying the elements-per-thread analysis to take a common object featuring thread X mapping information derived from constraints, so we don't have to look up the IR every time during a per-operation propagation call. This in turn caused some churn in tests, which now require at least an empty list of constraints.

This also generalizes the trait requiring sideways propagation between results, since similar behavior is necessary for reduction operations.
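
For reference, the new ops in IR look roughly as follows. The wave.sum syntax is taken from the tests added in this PR; the wave.max_element spelling is assumed analogous and the shapes are illustrative:

%sum = wave.sum %reg init(%init) along @M : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N] of f32, <register>>
%max = wave.max_element %reg init(%init) along @M : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N] of f32, <register>>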

@ftynse ftynse requested review from tgymnich and tyb0807 January 15, 2026 20:45
@ftynse (Contributor Author) commented Jan 15, 2026

@tyb0807 maybe for once you can actually review a PR I send you for review... You were the last one to touch elements-per-thread and I don't know if my reasoning here is correct.

@ftynse ftynse force-pushed the users/ftynse/reductions branch from 31a5527 to 06f8818 on January 15, 2026 21:10
@tgymnich tgymnich self-assigned this Jan 22, 2026
@tgymnich tgymnich linked an issue Jan 22, 2026 that may be closed by this pull request
@tgymnich tgymnich force-pushed the users/ftynse/reductions branch from 06f8818 to a3cca40 on January 22, 2026 16:13
Comment on lines +370 to +376
return wave::detail::identityElementsPerThreadPropagate(
operandElements, resultElements, "operands", "results", errs);
}
Contributor

IIUC, we (PyWave) currently only support/emit reductions along the fastest dimension. So basically we only need to care about two cases: whether the reduction dim is along thread X or not. IMHO, this only affects whether we need shuffles across threads, not what the final EPT should be, right?

If we enforce "reduction must be along the fastest dimension" in the verifier, then we should always set the result EPT to 1, as a reduction produces a scalar; there is no need for the if (init.threadXDimension == axis) distinction.

Is my understanding correct?

Contributor Author

What is the fastest dimension? How do we know which one it is? We can be in a situation where all EPTs are 1 already.

My initial logic was that, if we are reducing along the lane-mapped dimension, we will go from however many EPT we had to 1. If we are reducing along any other dimension, we should already have EPT 1. So we need to distinguish the two, at least in the error message.

Contributor Author

I now suppose that if the EPT corresponds to the symbol we reduce along, the EPT should become 1 in the result (we will first reduce within each thread, then reduce across threads via shuffles so that each thread has one element along that dimension and each thread holds a copy of the same value). If it corresponds to something else, these are individual rows and the EPT should remain unchanged. Does this sound reasonable?
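
For concreteness, a minimal sketch of what this two-stage reduction could lower to, using upstream MLIR ops (shapes and value names are illustrative; the PR's actual lowering via gpu::SubgroupReduceOp is quoted further down):

// Stage 1: each thread reduces its own 8 elements to a scalar.
%thread_reduce = vector.reduction <add>, %reg : vector<8xf32> into f32
// Stage 2: lanes combine partial sums; every lane ends up with the same
// value, matching the allreduce semantics discussed below.
%subgroup_reduce = gpu.subgroup_reduce add %thread_reduce : (f32) -> f32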

Comment on lines +380 to +391
if (init.threadXDimension == axis) {
// Reducing along the thread X, so mapped to lanes, means we will have one
// element per thread.
// TODO: same as above.
wave::ElementsPerThreadLatticeValue expectedOperand(1);
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
expectedOperand, {}, operandElements.slice(initOperandNum, 1),
"reduction along thread X dimension", "", "operands", errs);
Contributor

Likewise, I don't think we need to distinguish between thread X vs. non-thread X; the EPT will be 1 anyway, right?

Contributor Author

Ditto.

Comment on lines +389 to +394
// TODO: do we need to have elements per thread attribute here so we can set
// it as lattice value for input?
Contributor

I can't think of a way to infer input EPT from result EPT. I guess input EPT must come from the operation producing the input?

Contributor Author

Or it could be specified as an attribute on this operation, the same way we do for read/write, which is the purpose of this TODO.

// input type to have one more dimension that precisely matches the reduction
// axis.
template <typename OpTy>
static LogicalResult verifyReductionOperationTypes(OpTy op) {
Contributor

Maybe a NotImplemented verification for the case where reductions are not along the fastest dimension?

Contributor Author

Same as elsewhere, how do we know which is the fastest dimension?

@tgymnich tgymnich force-pushed the users/ftynse/reductions branch from a3cca40 to 090563f on January 23, 2026 13:47
@martin-luecke (Contributor) left a comment

Looks good overall with a few nits to address.

One question I have is whether we allow repeated symbols in WaveTensorType, since that breaks some of the logic here. AFAICS, we currently don't guard against this and should probably just add a check to the verifier of WaveTensorType.
However, I remember one case where @tgymnich needed one dimension to be a fraction of the size of another dim, but I think that was eventually not handled by repeating symbols across dimensions.


// -----

func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
Contributor

Suggested change
func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N, @M] of f32>) -> !wave.tensor<[@N] of f32> {

Comment on lines 420 to 424
// CHECK: wave.register {{.*}} : vector<8xf32>
%init = wave.register %c0 : !wave.tensor<[@N] of f32, <register>>

// CHECK: wave.sum {{.*}} : (vector<8xf32>, vector<1xf32>) -> vector<1xf32>
%sum = wave.sum %reg init(%init) along @M : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N] of f32, <register>>
Contributor

Init is first checked to be of type vector<8xf32> and then vector<1xf32>. As we reduce along the thread X dim, it should be the latter, right?

Comment on lines +361 to +368
// Reducing along the thread X, so mapped to lanes, means we will have one
// element per thread.
// TODO: not sure about that, it feels more like one element in general, not
// per thread.
Contributor

I think one element per thread would be consistent with the Python implementation.

Contributor Author

This sounds a bit strange with respect to the reduction semantics. Is this operation equivalent to MPI_AllReduce where all concurrent threads have the same value as a result (as opposed to only the leading thread)?

Contributor Author

We concluded that the semantics is allreduce.

Comment on lines +389 to +394
// TODO: do we need to have elements per thread attribute here so we can set
// it as lattice value for input?
Contributor

I don’t think we can infer the input elements_per_thread from the reduction alone.
Only result and accumulator are constrained here.

Contributor Author

Yep, that's the point of the todo. We may want to explicitly specify it on the operation.
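
For illustration, a hypothetical spelling of that (the attribute name mirrors the elements-per-thread attribute used by read/write, but this syntax is not part of this PR):

%sum = wave.sum %reg init(%init) along @M {elements_per_thread = 8 : i64} : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N] of f32, <register>>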

@ftynse ftynse force-pushed the users/ftynse/reductions branch from 090563f to b55cc87 on January 26, 2026 16:23
Comment on lines +839 to +840
// Expect PropagateElementsPerThread pass to have run, converting
// WaveTensorType results to VectorType.
Contributor Author

Should this be an assert?

Contributor

There is no normal form for this yet, right?

Contributor Author

I think the MemoryOnly form does it. It only allows tensors in global or shared memory and disallows tensors in registers, which means the latter should have been converted to vectors.
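
A minimal before/after sketch of that expectation, with the op taken from the tests in this PR (an EPT of 1 is assumed for illustration):

// Before PropagateElementsPerThread: register values use WaveTensorType.
%init = wave.register %c0 : !wave.tensor<[@N] of f32, <register>>

// After the pass: register-resident values carry concrete vector types.
%init = wave.register %c0 : vector<1xf32>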

Comment on lines 850 to 851
Value subgroup_reduce = gpu::SubgroupReduceOp::create(
rewriter, loc, thread_reduce, gpu::AllReduceOperation::ADD, false);
Contributor Author

We need to find a way to diagnose or at least highlight that the reduction is per-subgroup...

Contributor

We will need to mirror the block boolean parameter from PyWave.

Contributor Author

I meant that we should try inferring from the constraints whether we run in one subgroup (wave) or not, and complain if reductions are present when running in more than one subgroup. When we add the block boolean, we will need a similar check to see whether we run in a single block.

Contributor

done

ftynse and others added 8 commits January 28, 2026 11:59
Introduce an ODS class to add support for reduction operations and instantiate
it for sum and max_element reductions. Note the name of the latter: we may
eventually want a binary `max` operation whose name would clash, unlike the
add/sum dichotomy.

Implement type inference and elements per thread propagation using traits for
reductions. This required modifying elements per thread analysis to take a
common object featuring thread X mapping information derived from constraints
so we don't have to look up the IR every time during a per-operation
propagation call. This in turn caused some churn in tests that now require at
least an empty list of constraints.

This also generalized the trait requiring sideways propagation between results
since similar behavior is necessary for reduction operations.

Signed-off-by: Alex Zinenko <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
@tgymnich tgymnich force-pushed the users/ftynse/reductions branch from b55cc87 to 434e74e on January 28, 2026 12:09
tgymnich and others added 4 commits January 29, 2026 11:20
Signed-off-by: Tim Gymnich <[email protected]>
* the reduction axis attribute is optional and only allowed when the
  operand/result types are underspecified, to avoid duplication

* we only reduce along the trailing dimension

* instead of a boolean flag, we have an enum value specifically indicating
  whether the reduction is warp- or block-level, instead of implicitly
  defaulting to warp

Signed-off-by: Alex Zinenko <[email protected]>

Development

Successfully merging this pull request may close these issues.

[water] Implement wave.maximum
