
patch reduce_max and ConvertModel.forward retention of output activations#79

Open
sammlapp wants to merge 3 commits into Talmaj:master from sammlapp:master

Conversation


@sammlapp sammlapp commented Jan 6, 2026

These are small patches that I needed in order to convert an ONNX model (linked below) to PyTorch, as mentioned in #78.
onnx model: https://huggingface.co/justinchuby/Perch-onnx/blob/main/perch_v2_no_dft.onnx

For what it's worth, I also had to replace the Pad and Reshape ConvertModel layers to avoid incorrect padding and reshaping behavior. I don't know whether there is a generalizable fix for these issues.

# patches to fix ONNX → PyTorch conversion issues
import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx2pytorch import ConvertModel

class FixedONNXPad(nn.Module):  # patches onnx2pytorch's pad.Pad
    def forward(self, input, pads=None, value=0):
        if pads is None:
            raise TypeError("pads must be provided")

        pads = pads.tolist() if torch.is_tensor(pads) else list(pads)

        ndim = input.ndim
        assert len(pads) == 2 * ndim, (pads, input.shape)

        before = pads[:ndim]
        after = pads[ndim:]

        # Convert ONNX → PyTorch pad order
        torch_pads = []
        for b, a in zip(reversed(before), reversed(after)):
            torch_pads.extend([b, a])

        return F.pad(input, torch_pads, value=value)


class FixedReshape(nn.Module):
    # model-specific replacement: splits the last dim of a (B, H, Q, D)
    # tensor into (D // 4, 4) instead of running the original Reshape node
    def forward(self, x, shape=None):
        B, H, Q, D = x.shape
        assert D % 4 == 0
        return x.view(B, H, Q, D // 4, 4)
    
onnx_model = onnx.load("./perch_v2_no_dft.onnx") 
pytorch_model = ConvertModel(onnx_model)
pytorch_model.Pad_pad_output_0 = FixedONNXPad()
pytorch_model._modules[
    "Reshape_jit(infer_fn)/MultiHeadClassifier/MultiHeadClassifier._call_model/heads_protopnet_logits/dot_general10_reshaped_0"
] = FixedReshape()
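To illustrate why FixedONNXPad reverses the dimension order: ONNX lists all "begin" pads followed by all "end" pads, while torch.nn.functional.pad expects (begin, end) pairs starting from the last dimension. A minimal sketch of the conversion (pad values here are made up for the example):

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 4, 5)
# ONNX order: begins for each dim, then ends for each dim
onnx_pads = [0, 0, 1, 2, 0, 0, 3, 4]

ndim = x.ndim
before, after = onnx_pads[:ndim], onnx_pads[ndim:]

# PyTorch order: (begin, end) per dim, last dim first
torch_pads = []
for b, a in zip(reversed(before), reversed(after)):
    torch_pads.extend([b, a])

y = F.pad(x, torch_pads)
# dim -1 padded by (2, 4): 5 -> 11; dim -2 padded by (1, 3): 4 -> 8
assert tuple(y.shape) == (1, 3, 8, 11)
```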

Otherwise we get a KeyError below when collecting activations for the outputs, because an activation that is supposed to be included in the output has already been freed.
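The shape of the retention fix, sketched with plain dicts (names and structure are illustrative, not onnx2pytorch's actual internals): before freeing an intermediate activation, check whether its name is also a graph output.

```python
# Assumed setup: activations keyed by node name, some of which are
# also graph outputs that must survive until the end of forward().
output_names = ["logits", "embeddings"]          # graph output names (assumed)
activations = {"conv1": 1, "logits": 2, "embeddings": 3}
no_longer_needed = {"conv1", "logits"}

for name in no_longer_needed:
    if name not in output_names:                 # guard: never free an output
        del activations[name]

# collecting outputs no longer raises a KeyError
outputs = [activations[name] for name in output_names]
```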