[LinalgToXeGPU] Remove redundant linalg.broadcasts #419
Conversation
Signed-off-by: dchigarev <[email protected]>
```cpp
// Checks whether the given linalgOp operand is produced by a
// `linalg::BroadcastOp` that can be replaced by a simple subview
// (for example broadcast: 7x128 -> 1x1x7x128) and ensures that
// the broadcast result is only used by the linalgOp in question.
//
// If a valid `linalg::BroadcastOp` is found, the function removes it
// and returns the operand of the `linalg::BroadcastOp` as the new
// linalgOp operand. Otherwise returns the original operand.
static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
                                     size_t operandIdx,
                                     PatternRewriter &rewriter) {
```
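In other words, a broadcast qualifies only when a subview can express it. A minimal sketch, with illustrative shapes (the subview operands below are mine, not from the patch):

```mlir
// A broadcast that only adds unit dims replicates no data...
linalg.broadcast ins(%a : memref<7x128xf16>)
    outs(%b : memref<1x1x7x128xf16>) dimensions = [0, 1]
// ...so its result is just a rank-expanded view of %a, and any reader of
// %b can instead read %a through a rank-reducing subview:
%view = memref.subview %b[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1]
    : memref<1x1x7x128xf16> to memref<7x128xf16>
```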
You may ask: why can't we just lower every such linalg.broadcast into something like memref.expand_shape via a separate pattern, instead of doing this "find a broadcast that produces an operand of a linalg op that we're already lowering to xegpu" quest?
The problem is that the memref-to-spirv pass supports only a very limited set of memref ops. It's basically only memref.subview that is supported, and we can't expand memref shapes with it. So we can't just replace linalg.broadcast with memref.expand_shape, since our pipeline would then fail:
```mlir
// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) outs(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU
// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128xf16> to memref<1x7x128xf16> // <-- this will crash our pipeline
// ElementWiseToXeGPUPattern:
%out_squeeze = memref.subview %out : memref<1x7x128xf16> to memref<7x128xf16>
%desc = xegpu.create_tensor_desc %out_squeeze
```
And although a human eye can see here that the memref.expand_shape + memref.subview pair can be eliminated, none of the upstream passes can do that. Even if such an expand_shape/subview merger pass existed, we still could not guarantee that the memref.expand_shape is always followed by a rank-reducing memref.subview it can be merged with. Example:
```mlir
// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) outs(%out)
linalg.trickyOp ins(%out, ...)

// --------- after LinalgToXeGPU
// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128xf16> to memref<1x7x128xf16> // <-- this will crash our pipeline
// 'linalg.trickyOp' is not supported by the LinalgToXeGPU pass,
// so there is no rank-reducing memref.subview to merge 'expand_shape' with
linalg.trickyOp ins(%out, ...)
...
```
```mlir
// --------- after LinalgToLoops
// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128xf16> to memref<1x7x128xf16> // <-- this will crash our pipeline
for {
  for {
    for {
      %outScalar = memref.load %out
      arith.trickyOp %outScalar
      ...
    }
  }
}
```
So the only option we're left with is to "lower" a linalg.broadcast only when it produces an operand of a linalgOp that we are lowering to xegpu right now, and to do so simply by erasing the broadcastOp and forwarding its input to the linalgOp in question. Example:
```mlir
// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) outs(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU
// ElementWiseToXeGPUPattern:
%inp = memref.alloc() : memref<7x128xf16>
%desc = xegpu.create_tensor_desc %inp
...
```
```cpp
pm.addPass(createDecomposeTensorOperation());
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
pm.addPass(createCanonicalizerPass());
```
We should do the 'cleaning' right after the tiling; otherwise the bufferization pass may produce memref.cast ops that cannot be lowered by memref-to-spirv.
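For illustration, a minimal sketch of the kind of cast that can appear; the exact type pair here is an assumption, not taken from the pipeline:

```mlir
// Assumed example: bufferization reconciles two layouts by casting to a
// more general strided type with a dynamic offset.
%buf  = memref.alloc() : memref<7x128xf16>
%cast = memref.cast %buf
    : memref<7x128xf16> to memref<7x128xf16, strided<[128, 1], offset: ?>>
// memref-to-spirv has no pattern for memref.cast, so conversion stops here.
```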
kurapov-peter left a comment:
It looks like we are just squeezing and expanding shapes back and forth. I wonder if we could just avoid creating the broadcast in the first place. That would have to go into the original module creation, and I see how that is tricky when we perform the conversion one operator at a time (e.g., given two elementwise ops, the first is converted to match the output shape with a broadcast; it only becomes clear that it is redundant once we have the complete module; see the sketch below).
Let's test this out; it's done high enough in the pipeline not to change the overall behavior too much.
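A hypothetical sketch of that one-op-at-a-time situation (names and shapes are illustrative):

```mlir
%a   = memref.alloc() : memref<7x128xf16>
%b   = memref.alloc() : memref<1x7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
%tmp = memref.alloc() : memref<1x7x128xf16>
// op 1, converted in isolation: a broadcast is emitted so %a matches the
// 1x7x128 shape the consumer declares
linalg.broadcast ins(%a : memref<7x128xf16>)
    outs(%tmp : memref<1x7x128xf16>) dimensions = [0]
// op 2: only with the whole module visible is it clear that the leading
// unit dim, and therefore the broadcast, was never needed
linalg.add ins(%tmp, %b : memref<1x7x128xf16>, memref<1x7x128xf16>)
    outs(%out : memref<1x7x128xf16>)
```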
Signed-off-by: dchigarev <[email protected]>
The PR implements logic that removes linalg.broadcast ops that do not perform any actual broadcasting (for example, a 7x128 -> 1x1x7x128 broadcast that only adds unit dimensions).
Why not support all broadcast cases?
A proper lowering for broadcast can be tricky, since xegpu only supports 2D memrefs, while broadcast is always a shape-expanding operation, so there is always at least one operand that is not 2D.
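For illustration, a minimal sketch (with assumed shapes) of a genuine broadcast that would need real data replication and thus has no direct xegpu lowering:

```mlir
// The 1D input has no 2D tensor-descriptor representation, and its 128
// elements must actually be replicated across the 7 output rows.
%row = memref.alloc() : memref<128xf16>
%mat = memref.alloc() : memref<7x128xf16>
linalg.broadcast ins(%row : memref<128xf16>)
    outs(%mat : memref<7x128xf16>) dimensions = [0]
```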