Description
I am running the Torch MLIR to TOSA MLIR conversion pipeline (command below), but I keep getting an error saying torch.aten.mm was explicitly marked illegal. Can someone help me figure out which passes to run so that the TOSA conversion succeeds?
torch-mlir-opt --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline torch.mlir -o tosa.mlir
torch.mlir:7:10: error: failed to legalize operation 'torch.aten.mm' that was explicitly marked illegal
%2 = torch.aten.mm %arg0, %1 : !torch.vtensor<[32,4096],f32>, !torch.vtensor<[4096,128256],bf16> -> !torch.vtensor<[32,128256],f32>
^
torch.mlir:7:10: note: see current operation: %8 = "torch.aten.mm"(%arg0, %7) : (!torch.vtensor<[32,4096],f32>, !torch.vtensor<[4096,128256],bf16>) -> !torch.vtensor<[32,128256],f32>
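Both copies of the failing op in the log show mixed operand element types: the left operand is `f32` and the right operand is `bf16`. Nothing here confirms that this is why the TOSA lowering rejects the op. Still, it is easy to check whether other `torch.aten.mm` ops in the file follow the same pattern. The following is a minimal stdlib-only sketch. It assumes ops are printed one per line with ranked `!torch.vtensor` types, as in the log. `operand_dtypes` and `find_mixed_dtype_ops` are hypothetical helper names.

```python
import re
import sys

# Captures the element type of a ranked tensor type such as
#   !torch.vtensor<[32,4096],f32>  ->  "f32"
# Assumption: unranked (!torch.vtensor<*,f32>) types are not handled.
VTENSOR_RE = re.compile(r"!torch\.vtensor<\[[^\]]*\],(\w+)>")


def operand_dtypes(op_line: str) -> list[str]:
    """Return the element types of the operand types (everything before '->')."""
    operand_part = op_line.split("->", 1)[0]
    return VTENSOR_RE.findall(operand_part)


def find_mixed_dtype_ops(mlir_text: str, op_name: str = "torch.aten.mm"):
    """Yield (line_number, dtypes) for each op whose operands disagree on dtype."""
    for lineno, line in enumerate(mlir_text.splitlines(), start=1):
        if op_name not in line:
            continue
        dtypes = operand_dtypes(line)
        if len(set(dtypes)) > 1:
            yield lineno, dtypes


if __name__ == "__main__":
    # Usage: python check_mm_dtypes.py torch.mlir
    path = sys.argv[1] if len(sys.argv) > 1 else "torch.mlir"
    with open(path) as f:
        for lineno, dtypes in find_mixed_dtype_ops(f.read()):
            print(f"{path}:{lineno}: mixed operand dtypes {dtypes}")
```

Running this against the op from line 7 of the report would print a single hit listing `['f32', 'bf16']`. If that is the cause, casting one operand so that both share an element type before export is a plausible workaround, but this sketch does not verify that.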
fx.export_and_import also fails on this when targeting the TOSA dialect, with the following error:
/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/UnnammedModule.mlir
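The reproduction command above runs only the TOSA stage. The sketch below runs the two pipelines from the original command as separate invocations. This keeps the intermediate Torch backend IR on disk and makes it clear which stage fails. It assumes `torch-mlir-opt` is on `PATH`. `--mlir-print-ir-after-failure` is a generic MLIR pass-manager option that dumps the IR when a pass fails. The stage split, output file names, and helper functions are hypothetical additions for debugging.

```python
import subprocess
import sys

# Pipeline flags come from the original command; each stage is run on its own.
STAGES = [
    ("torch-backend", "--torch-function-to-torch-backend-pipeline"),
    ("tosa", "--torch-backend-to-tosa-backend-pipeline"),
]


def build_commands(src: str, tool: str = "torch-mlir-opt") -> list[list[str]]:
    """Return one argv list per stage; each stage reads the previous stage's output."""
    commands = []
    current = src
    for name, flag in STAGES:
        out = f"{name}.mlir"  # hypothetical output name, e.g. tosa.mlir
        commands.append([tool, flag, "--mlir-print-ir-after-failure", current, "-o", out])
        current = out
    return commands


def run(src: str) -> None:
    for argv in build_commands(src):
        print("running:", " ".join(argv))
        # stderr is not captured, so the IR dumped on failure reaches the terminal.
        result = subprocess.run(argv)
        if result.returncode != 0:
            sys.exit(f"stage failed: {argv[1]}")


if __name__ == "__main__":
    run(sys.argv[1] if len(sys.argv) > 1 else "torch.mlir")
```

If the first stage succeeds, torch-backend.mlir can be inspected directly. For example, it shows whether torch.aten.mm still has mixed operand dtypes just before the TOSA lowering.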