Skip to content

Commit f03a576

Browse files
authored
[TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. (llvm#3759)
This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns in their own functions: 1. populateTorchToTosaConversionLegalOps -- populate any ops that are legal after the conversion pass 2. populateTorchToTosaConversionIllegalOps -- populate any ops that are illegal after the conversion pass 3. populateTorchToTosaConversionPatterns -- populate the ops conversion patterns Currently the (il)legality of the ops that are (il)legal after the conversion pass runs is embedded within the conversion pattern. Our end goal is to write a new pass pipeline that converts `torch` ops to a mix of `tosa`, `linalg`, `tensor`, etc dialect ops. The reason we want to also emit `tosa` ops (instead of using the existing `TorchToLinalg` to emit `linalg`+`tensor`+...) is because some operations like `conv2d` encodes the padding behavior in the op in `tosa` unlike the `linalg` version -- this helps in lowering the `tosa.conv2d` to a custom implementation that does padding on the fly. To implement this new pipeline we need to be able to separate out the illegal `tosa` ops from the conversion pattern itself. Otherwise we will hit an issue for ops like `AtenMaxDimOp` which can be lowered to both `tosa` and `linalg + others` dialects. Not all `AtenMaxDimOp` can be lowered successfully to `tosa` as the implementation uses `tosa.reshape` which cannot handle multiple dynamic dimensions but the `TorchToLinalg` lowering can handle it. In the current behavior the pipeline will stop as soon as the existing `TorchToTosa` conversion runs as `AtenMaxDimOp` will be marked as an illegal op. Essentially we want to be able to control what the legality of the ops should be independent of the conversion pattern. This is also inline with the conversion patterns in the llvm-mlir repo such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718 "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY."
1 parent 5a5cc6b commit f03a576

File tree

2 files changed

+249
-222
lines changed

2 files changed

+249
-222
lines changed

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,25 @@
1212

1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/DialectConversion.h"
16+
1517
#include <memory>
1618

1719
namespace mlir {
1820
namespace torch {
21+
22+
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
23+
/// dialect.
24+
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
25+
26+
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
27+
/// return the set of illegalOps
28+
std::set<StringRef>
29+
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
30+
RewritePatternSet &patterns);
31+
1932
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
20-
}
33+
} // namespace torch
2134
} // namespace mlir
2235

2336
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

0 commit comments

Comments
 (0)