[water] add support for reduction operations #741
base: main
Conversation
@tyb0807 maybe for once you can actually review the PRs I send you for review... You were the last one to touch elements-per-thread and I don't know if my reasoning here is correct.
Force-pushed 31a5527 to 06f8818
Force-pushed 06f8818 to a3cca40
  return wave::detail::identityElementsPerThreadPropagate(
      operandElements, resultElements, "operands", "results", errs);
}
IIUC, we (PyWave) currently only support/emit reductions along the fastest dimension. So basically we only need to care about two cases: whether the reduction dim is along thread X or not. IMHO, this only affects whether we need shuffles across threads, not what the final EPT should be, right?
If we enforce "reduction must be along the fastest dimension" in the verifier, then we should always set result EPT = 1, as the reduction produces a scalar; there is no need for the if (init.threadXDimension == axis) distinction.
Is my understanding correct?
What is the fastest dimension? How do we know which one it is? We can be in a situation where all EPTs are already 1.
My initial logic was that, if we are reducing along the lane-mapped dimension, we go from however many EPT we had down to 1. If we are reducing along any other dimension, we should already have EPT 1. So we need to distinguish, at least in the error message.
I now suppose that if the EPT corresponds to the symbol we reduce along, the EPT should become 1 in the result (we will first reduce within each thread, then reduce across threads via shuffles so that each thread has one element along that dimension and every thread holds a copy of the same value). If it corresponds to something else, these are individual rows and the EPT should remain unchanged. Does this sound reasonable?
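To make the proposed rule concrete, here is a minimal standalone sketch of the propagation. The types and names are hypothetical stand-ins (EptLattice mirrors wave::ElementsPerThreadLatticeValue); this is not the actual water analysis code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for wave::ElementsPerThreadLatticeValue.
struct EptLattice {
  int64_t elementsPerThread;
};

// Proposed rule from the discussion above: reducing along the dimension
// mapped to thread X collapses the EPT to 1 in the result (per-thread
// reduction followed by cross-lane shuffles); reducing along any other
// dimension leaves the EPT unchanged, since each thread keeps its own rows.
EptLattice propagateReductionEpt(EptLattice operand, bool reducingAlongThreadX) {
  if (reducingAlongThreadX)
    return EptLattice{1};
  return operand;
}

int main() {
  assert(propagateReductionEpt({8}, /*reducingAlongThreadX=*/true).elementsPerThread == 1);
  assert(propagateReductionEpt({8}, /*reducingAlongThreadX=*/false).elementsPerThread == 8);
  return 0;
}
```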
if (init.threadXDimension == axis) {
  // Reducing along the thread X, so mapped to lanes, means we will have one
  // element per thread.
  // TODO: same as above.
  wave::ElementsPerThreadLatticeValue expectedOperand(1);
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      expectedOperand, {}, operandElements.slice(initOperandNum, 1),
      "reduction along thread X dimension", "", "operands", errs);
Likewise I don't think we need to distinguish between thread X vs. non thread X, EPT will be 1 anyway, right?
Ditto.
// TODO: do we need to have elements per thread attribute here so we can set
// it as lattice value for input?
I can't think of a way to infer input EPT from result EPT. I guess input EPT must come from the operation producing the input?
Or it could be specified as an attribute on this operation, the same way we do for read/write, which is the purpose of this TODO.
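A minimal sketch of what that TODO could look like, assuming a hypothetical optional elements_per_thread attribute on the reduction op (names invented for illustration, not the actual water API):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in lattice value, as in the sketch further up.
struct EptLattice {
  int64_t elementsPerThread;
};

// If the reduction op carried an optional elements_per_thread attribute (as
// read/write do), it could seed the lattice value of the input operand;
// without it, the input EPT can only come from the producing operation.
std::optional<EptLattice> seedInputEpt(std::optional<int64_t> eptAttr) {
  if (eptAttr)
    return EptLattice{*eptAttr};
  return std::nullopt;
}

int main() {
  assert(seedInputEpt(8)->elementsPerThread == 8);
  assert(!seedInputEpt(std::nullopt).has_value());
  return 0;
}
```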
// input type to have one more dimension that precisely matches the reduction
// axis.
template <typename OpTy>
static LogicalResult verifyReductionOperationTypes(OpTy op) {
Maybe a NotImplemented verification for the case where reductions are not along the fastest dimension?
Same as elsewhere, how do we know which is the fastest dimension?
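For illustration, here is a standalone sketch of the shape rule this verifier is described as enforcing, using plain strings in place of the @-symbols of !wave.tensor (a hypothetical helper, not the actual verifier code):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Shape rule: the input has exactly one extra dimension, and removing the
// reduction axis from the input must yield the init/result shape.
bool verifyReductionShapes(const std::vector<std::string> &inputDims,
                           const std::vector<std::string> &resultDims,
                           const std::string &axis) {
  if (inputDims.size() != resultDims.size() + 1)
    return false;
  std::vector<std::string> remaining;
  bool dropped = false;
  for (const std::string &dim : inputDims) {
    if (!dropped && dim == axis) {
      dropped = true;
      continue;
    }
    remaining.push_back(dim);
  }
  return dropped && remaining == resultDims;
}

int main() {
  // wave.sum %reg along @M : [@M, @N] -> [@N]
  assert(verifyReductionShapes({"M", "N"}, {"N"}, "M"));
  // Rank mismatch: the input must have exactly one extra dimension.
  assert(!verifyReductionShapes({"N", "M"}, {"N", "M"}, "M"));
  return 0;
}
```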
Force-pushed a3cca40 to 090563f
martin-luecke left a comment:
Looks good overall with a few nits to address.
One question I have is whether we allow repeated symbols in WaveTensorType, since that breaks some of the logic here. AFAICS, we currently don't guard against this and should probably just add it to the verifier of WaveTensorType.
However, I remember one case where @tgymnich needed one dimension to be a fraction of the size of another dim, but I think that was eventually handled without using the same symbol for both dimensions.
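A minimal sketch of the duplicate-symbol guard suggested here, with symbols as plain strings (a hypothetical helper, not the actual WaveTensorType verifier):

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Reject repeated symbols in the shape of a tensor type.
bool hasUniqueSymbols(const std::vector<std::string> &dims) {
  std::set<std::string> seen;
  for (const std::string &dim : dims)
    if (!seen.insert(dim).second)
      return false;
  return true;
}

int main() {
  assert(hasUniqueSymbols({"M", "N"}));
  assert(!hasUniqueSymbols({"N", "N"}));
  return 0;
}
```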
// -----

func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
Suggested change:
- func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
+ func.func @rank_mismatch(%input: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N, @M] of f32>) -> !wave.tensor<[@N] of f32> {
// CHECK: wave.register {{.*}} : vector<8xf32>
%init = wave.register %c0 : !wave.tensor<[@N] of f32, <register>>

// CHECK: wave.sum {{.*}} : (vector<8xf32>, vector<1xf32>) -> vector<1xf32>
%sum = wave.sum %reg init(%init) along @M : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N] of f32, <register>>
The init is checked once to be of type vector<8xf32> and then vector<1xf32>. As we reduce along the thread X dimension, it should be the latter, right?
// Reducing along the thread X, so mapped to lanes, means we will have one
// element per thread.
// TODO: not sure about that, it feels more like one element in general, not
// per thread.
I think one element per thread would be consistent with the Python implementation.
This sounds a bit strange with respect to the reduction semantics. Is this operation equivalent to MPI_AllReduce where all concurrent threads have the same value as a result (as opposed to only the leading thread)?
We concluded that the semantics is allreduce.
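To illustrate the agreed allreduce semantics (every lane ends up with the same reduced value, not just the leading lane), here is a small standalone simulation of an XOR-butterfly shuffle reduction over a warp of lane-local partial sums; the actual lowering uses gpu.subgroup_reduce rather than this hand-rolled loop.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Simulate an XOR-butterfly allreduce over a warp: after log2(width)
// exchange rounds, every lane holds the full sum.
std::vector<float> simulateAllReduceAdd(std::vector<float> lanes) {
  const std::size_t width = lanes.size(); // assumed to be a power of two
  for (std::size_t offset = 1; offset < width; offset *= 2) {
    std::vector<float> next(width);
    for (std::size_t lane = 0; lane < width; ++lane)
      next[lane] = lanes[lane] + lanes[lane ^ offset];
    lanes = next;
  }
  return lanes;
}

int main() {
  std::vector<float> lanes = {1, 2, 3, 4};
  std::vector<float> result = simulateAllReduceAdd(lanes);
  float total = std::accumulate(lanes.begin(), lanes.end(), 0.0f);
  for (float value : result)
    assert(value == total); // every "thread" sees the same result
  return 0;
}
```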
// TODO: do we need to have elements per thread attribute here so we can set
// it as lattice value for input?
I don’t think we can infer the input elements_per_thread from the reduction alone.
Only result and accumulator are constrained here.
Yep, that's the point of the todo. We may want to explicitly specify it on the operation.
Force-pushed 090563f to b55cc87
// Expect PropagateElementsPerThread pass to have run, converting
// WaveTensorType results to VectorType.
Should this be an assert?
There is no normal form for this yet, right?
I think the MemoryOnly form does it. It only allows tensors in global or shared memory and disallows tensors in registers, which means the latter should have been converted to vectors.
Value subgroup_reduce = gpu::SubgroupReduceOp::create(
    rewriter, loc, thread_reduce, gpu::AllReduceOperation::ADD, false);
We need to find a way to diagnose or at least highlight that the reduction is per-subgroup...
We will need to mirror the block boolean parameter from PyWave.
I meant that we should try inferring whether we run in one subgroup (wave) or not from constraints, and complain if there are reductions present when running in more than one subgroup. When we add the block boolean, we will need a similar check to see whether we run in a single block.
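A sketch of the kind of check described above, with the subgroup count assumed to come from the constraints (hypothetical interface and invented names, not the actual water implementation):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Derive the number of subgroups in a block from the (constraint-provided)
// workgroup size and subgroup size, and reject a subgroup-level reduction
// when the kernel runs with more than one subgroup per block.
std::optional<std::string> verifySingleSubgroupReduction(int64_t workgroupSize,
                                                         int64_t subgroupSize) {
  int64_t subgroupsPerBlock = workgroupSize / subgroupSize;
  if (subgroupsPerBlock > 1)
    return std::string("subgroup reduction with ") +
           std::to_string(subgroupsPerBlock) +
           " subgroups per block is not implemented";
  return std::nullopt;
}

int main() {
  assert(!verifySingleSubgroupReduction(64, 64).has_value());
  assert(verifySingleSubgroupReduction(256, 64).has_value());
  return 0;
}
```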
done
Introduce an ODS class to add support for reduction operations and instantiate it for sum and max_element reductions. Note the name of the latter: we may eventually want to have a binary `max` operation that would clash, unlike the add/sum dichotomy.

Implement type inference and elements per thread propagation using traits for reductions. This required modifying the elements per thread analysis to take a common object featuring thread X mapping information derived from constraints, so we don't have to look up the IR every time during a per-operation propagation call. This in turn caused some churn in tests, which now require at least an empty list of constraints.

This also generalized the trait requiring sideways propagation between results, since similar behavior is necessary for reduction operations.

Signed-off-by: Alex Zinenko <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: Alex Zinenko <[email protected]>
Signed-off-by: Alex Zinenko <[email protected]>
Signed-off-by: Tim Gymnich <[email protected]>
Force-pushed b55cc87 to 434e74e
Signed-off-by: Tim Gymnich <[email protected]>
Signed-off-by: tyb0807 <[email protected]>
…nstraints
Signed-off-by: Tim Gymnich <[email protected]>
* the reduction axis attribute is optional and only allowed when the operand/result types are underspecified, to avoid duplication
* we only reduce along the trailing dimension
* instead of a boolean flag, we have an enum value specifically indicating whether the reduction is warp or block, instead of implicitly defaulting to warp

Signed-off-by: Alex Zinenko <[email protected]>