Skip to content

Converting captum functions to ONNX #1778

@Rayndell

Description

@Rayndell

I am trying to export and load an ONNX model which generates captum heatmaps as outputs and thus encapsulates captum functions. My export code looks like this:

class ClassificationModel(torch.nn.Module):
    """Wrap a classifier so its forward pass emits both the raw class scores
    and (optionally) one attribution heatmap per class.

    Parameters
    ----------
    base_model : torch.nn.Module
        The underlying classifier.
    device : torch.device
        Device used for export and heatmap generation.
    resize_shape : tuple[int, int] | None
        Optional (height, width) passed to the preprocessing pipeline.
    heatmap_generator : object | None
        Object exposing ``generate(inputs, num_classes, device)``; when None,
        an all-zeros heatmap tensor is returned instead.
    """

    def __init__(self, base_model, device, resize_shape=None, heatmap_generator=None):
        super().__init__()
        self.base_model = base_model
        self.device = device
        # Preprocessing is fixed at construction time (no augmentation).
        self.transform = get_preprocessing(resize_shape=resize_shape, data_augmentation=False)
        self.heatmap_generator = heatmap_generator

    def forward(self, inputs):
        """Return ``(class_scores, heatmaps)`` for a batch of raw images."""
        transformed = self.transform(inputs)
        scores = self.base_model(transformed)

        # Default heatmaps: zeros matching the preprocessed input shape.
        heatmaps = torch.zeros(transformed.shape)
        if self.heatmap_generator:
            # One heatmap per output class; scores.shape[1] is the class count.
            heatmaps = self.heatmap_generator.generate(transformed, scores.shape[1], self.device)
            # Bring the heatmaps back to the preprocessed image resolution.
            height, width = transformed.shape[-2:]
            heatmaps = get_resize((height, width))(heatmaps)

        return scores, heatmaps

    def export_to_onnx(self, path):
        """Trace this module with a dummy uint8 image and write an ONNX file to *path*."""
        self.eval()

        input_names = ["input"]
        output_names = ["output"]
        dynamic_axes = {"input": {0: "batch_size", 2: "height", 3: "width"}}

        if self.heatmap_generator:
            output_names.append("heatmaps")
            dynamic_axes["heatmaps"] = {0: "batch_size", 2: "height", 3: "width"}

        self.to(self.device)
        # Dummy input: one 512x512 RGB image with uint8 pixel values.
        example = torch.randint(0, 255, (1, 3, 512, 512), dtype=torch.uint8).to(self.device)
        torch.onnx.export(
            self,
            example,
            path,
            export_params=True,
            opset_version=16,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            dynamo=False,
        )

# Reading the configuration file
config = get_config(args.config_file)
resize_shape = (config.resize_height, config.resize_width) if config.resize else None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the given checkpoint and retrieve the class names
checkpoint = torch.load(args.checkpoint_path, weights_only=True)
class_names = ["CD","CV","Pachy","SD","Hex","Max","Min"]
if "class_names" in checkpoint:
    class_names = checkpoint['class_names']

# Rebuild the model
model = build_model(config, len(class_names))

# Load the model weights
if "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
else:
    model.load_state_dict(checkpoint, strict=False)

# Create a heatmap generator
heatmap_generator = HeatmapGenerator(model, heatmap_type="IntegratedGradients")

# FIX: export_to_onnx is a method of ClassificationModel taking only a path —
# there is no free export_to_onnx(model, path, device, ...) function. The bare
# model must be wrapped in ClassificationModel so that preprocessing and the
# heatmap generator are part of the traced graph, then the method is called.
wrapper = ClassificationModel(model, device, resize_shape=resize_shape,
                              heatmap_generator=heatmap_generator)
wrapper.export_to_onnx(args.output)

Here is the captum code I try to encapsulate:

import torch
from captum.attr import IntegratedGradients

def cumulative_sum_threshold_tensor(values_tensor, percentile):
    """Return the value at which the running sum of the sorted (flattened)
    values first reaches *percentile* percent of the total sum.

    The result is a 1-element tensor holding that value.
    """
    flat_sorted, _ = torch.sort(values_tensor.flatten())
    running = torch.cumsum(flat_sorted, dim=0)
    # running[-1] is the grand total; scale percentile from [0, 100] to a fraction.
    cutoff = running[-1] * 0.01 * percentile
    first_idx = (running >= cutoff).nonzero()[0]
    return flat_sorted[first_idx]

def normalize_scale_tensor(attr_tensor, scale_factor):
    """Divide attributions by *scale_factor* and clip the result to [-1, 1]."""
    scaled = attr_tensor / scale_factor
    return scaled.clamp(-1, 1)

def normalize_attr_tensor(attr_tensor, sign, outlier_perc, device):
    """Normalize an attribution map to [-1, 1], discarding outliers.

    Keeps only the requested *sign* of the attributions ("positive",
    "negative", or "absolute"; any other value keeps them as-is), computes a
    scale threshold from the cumulative-sum percentile ``100 - outlier_perc``,
    and rescales/clips the result.

    The *device* parameter is kept for backward compatibility; it is no longer
    needed because clamping operates in place of an explicit zeros tensor.
    """
    if sign == "positive":
        # Equivalent to torch.maximum(attr, zeros) but avoids allocating a
        # full zeros tensor and a host->device transfer per call.
        attr_tensor = torch.clamp(attr_tensor, min=0)
    elif sign == "negative":
        attr_tensor = torch.clamp(attr_tensor, max=0)
    elif sign == "absolute":
        attr_tensor = torch.abs(attr_tensor)

    threshold = cumulative_sum_threshold_tensor(torch.abs(attr_tensor), 100.0 - outlier_perc)

    # For negative-only attributions the scale must be negative so the
    # normalized values keep their sign.
    if sign == "negative":
        threshold = -threshold

    # NOTE(review): if the attribution map is all zeros, threshold is 0 and the
    # division in normalize_scale_tensor yields NaNs — confirm callers never
    # pass an all-zero map, or guard upstream.
    return normalize_scale_tensor(attr_tensor, threshold)

class HeatmapGenerator:
    """Produce per-class attribution heatmaps for a model via captum.

    Currently only ``heatmap_type="IntegratedGradients"`` installs an
    attribution backend; any other value leaves ``attribute_generator`` None
    (and ``generate`` would then fail if called).
    """

    def __init__(self, model, heatmap_type=None):
        self.model = model
        self.attribute_generator = (
            IntegratedGradients(model, multiply_by_inputs=False)
            if heatmap_type == "IntegratedGradients"
            else None
        )

    def generate(self, inputs, num_targets, device):
        """Return a tensor of heatmaps with shape (num_targets, batch, H, W)
        stacked along dim=1, i.e. (batch, num_targets, H, W)."""
        per_class = []
        for target in range(num_targets):
            # Attributions for this class, one map per batch sample.
            attr = self.attribute_generator.attribute(inputs, target=target, n_steps=10)
            # Collapse channels to a single gray map, then normalize each sample.
            normalized = [
                normalize_attr_tensor(torch.sum(sample, 0), "positive", 1, device)
                for sample in attr
            ]
            per_class.append(torch.stack(normalized))
        return torch.stack(per_class, dim=1)

The heatmaps can be successfully retrieved from the exported ONNX model during inference. But they are always the same whatever input image is provided. From what I understood there might be some parameters converted to constants during the export, so that the resulting heatmaps are constant and inferred from the input_example provided to the torch.onnx.export function. Here is the output during export:

C:\Users\aurelien.ADCIS-Caen\AppData\Local\Programs\Python\Python312\Lib\site-packages\captum\_utils\gradient.py:133: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert outputs[0].numel() == 1, (
C:\Users\aurelien.ADCIS-Caen\AppData\Local\Programs\Python\Python312\Lib\site-packages\captum\attr\_core\integrated_gradients.py:379: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

  • torch.tensor(step_sizes).float().view(n_steps, 1).to(grad.device)

Has anyone experienced the same issue, or tried to export captum functions to ONNX? Is it even possible? What could be a workaround for this?

Thank you very much.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions