torch.aten.upsample_nearest2d.vec cannot be legalized in TorchToLinalg #4327

@raayandhar

Running this reproducer:

func.func @new_small_repro (%24538: tensor<2x1280x32x32xf16>) -> tensor<2x1280x64x64xf16> {
    %float2.000000e00 = torch.constant.float 2.000000e+00
    %none = torch.constant.none
    %24539 = torch_c.from_builtin_tensor %24538 : tensor<2x1280x32x32xf16> -> !torch.vtensor<[2,1280,32,32],f16>
    %24540 = torch.prim.ListConstruct %float2.000000e00, %float2.000000e00 : (!torch.float, !torch.float) -> !torch.list<float>
    %24541 = torch.aten.upsample_nearest2d.vec %24539, %none, %24540 : !torch.vtensor<[2,1280,32,32],f16>, !torch.none, !torch.list<float> -> !torch.vtensor<[2,1280,64,64],f16>
    %24542 = torch_c.to_builtin_tensor %24541 : !torch.vtensor<[2,1280,64,64],f16> -> tensor<2x1280x64x64xf16>
    func.return %24542: tensor<2x1280x64x64xf16>
}

returns the following error:

new_small_repro.mlir:2:25: error: failed to legalize operation 'torch.constant.float'
    %float2.000000e00 = torch.constant.float 2.000000e+00
                        ^
new_small_repro.mlir:2:25: note: see current operation: %0 = "torch.constant.float"() <{value = 2.000000e+00 : f64}> : () -> !torch.float

However, running iree-opt --pass-pipeline=builtin.module(func.func(convert-torch-to-linalg)) --debug and looking through the logs (gist here) shows that the conversion is unable to legalize torch.aten.upsample_nearest2d.vec. There is a ConvertAtenUpsampleNearest2dOp pattern in TorchToLinalg/IndirectDataMovement.cpp, but it does not appear anywhere in the logs: patterns for other ops are tried and match successfully, yet ConvertAtenUpsampleNearest2dOp is never attempted. It seems the conversion does not recognize that it can use this pattern for the op.

I'm hoping to get some insight into why it is unable to handle this operation. Thanks!
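
For reference, here is a minimal pure-Python sketch of what the op in the reproducer is expected to compute, assuming the usual floor-based nearest-neighbor index mapping (source index = floor(destination index / scale)). The helper name upsample_nearest2d_plane is hypothetical and works on a single H x W plane; the batch and channel dimensions (2 and 1280 in the reproducer) are assumed to pass through unchanged. It is not the torch-mlir lowering, only an illustration of the expected result and shapes.

```python
import math


def upsample_nearest2d_plane(plane, scale_h, scale_w):
    """Nearest-neighbor upsample of one H x W plane given as a list of rows.

    Hypothetical helper for illustration only; assumes the floor-based
    mapping src = floor(dst / scale) for each spatial dimension.
    """
    in_h, in_w = len(plane), len(plane[0])
    out_h = math.floor(in_h * scale_h)
    out_w = math.floor(in_w * scale_w)
    out = []
    for oh in range(out_h):
        # Clamp so rounding can never index past the last input row.
        ih = min(math.floor(oh / scale_h), in_h - 1)
        row = []
        for ow in range(out_w):
            iw = min(math.floor(ow / scale_w), in_w - 1)
            row.append(plane[ih][iw])
        out.append(row)
    return out


# Tiny example: every input element becomes a 2x2 block.
plane = [[1, 2], [3, 4]]
result = upsample_nearest2d_plane(plane, 2.0, 2.0)

# Same spatial shapes as the reproducer: 32x32 with scale 2.0 gives 64x64.
big = [[r * 32 + c for c in range(32)] for r in range(32)]
big_out = upsample_nearest2d_plane(big, 2.0, 2.0)
```

With scale factors [2.0, 2.0] this matches the result type in the reproducer, where the 32x32 spatial dimensions become 64x64.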
