CROWN compute_bounds error on matmul #94

@elfman2

Description

Describe the bug
The calls to bm.forward and bm.compute_bounds(method='IBP') succeed,
but bm.compute_bounds(method='CROWN') fails with a RuntimeError (tensor size mismatch in the multiplication relaxation; full traceback below).

To Reproduce

import os
os.environ['AUTOLIRPA_DEBUG']='1'
os.environ['AUTOLIRPA_DEBUG_NAMES']='1'
import torch
from torch import nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor, register_custom_op
from auto_LiRPA.operators import BoundAdd, BoundRelu, Bound
from auto_LiRPA.perturbations import PerturbationLpNorm
from auto_LiRPA.operators.linear import BoundMatMul
import numpy as np
def boundTensor(size, mini=0., maxi=1.):  # wrap an input with an L-inf box perturbation in [mini, maxi]
    return BoundedTensor(torch.empty(size), 
                         PerturbationLpNorm(norm = float("inf"), 
                                            x_L=torch.full(size,mini), 
                                            x_U=torch.full(size,maxi)))
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
    def forward(self, x:torch.Tensor):
        return torch.matmul(torch.ones(3),x)  # constant vector on the left, perturbed input on the right
bm = BoundedModule(Test(),torch.empty((3,2)),verbose=True)
bm.forward(torch.ones((3,2)))                               # OK
bm.compute_bounds(x=(boundTensor((3,2))), method='IBP')     # OK
bm.compute_bounds(x=(boundTensor((3,2))), method='CROWN')   # fails with RuntimeError (see traceback below)
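
For reference, my hand computation of what the bounds should be (assuming exact arithmetic): every output entry is the sum of three inputs lying in [0, 1], so both IBP and CROWN should return lower = [0, 0] and upper = [3, 3]. The small check below uses plain PyTorch only; since the constant weights are nonnegative, the bounds are attained at x_L and x_U.

x_L, x_U = torch.zeros(3, 2), torch.ones(3, 2)
print(torch.ones(3) @ x_L)  # tensor([0., 0.]) -- expected lower bound
print(torch.ones(3) @ x_U)  # tensor([3., 3.]) -- expected upper bound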

System configuration:

  • OS: Red Hat Enterprise Linux 9.5 (Plow)
  • Python version: 3.11
  • PyTorch version: 2.4.1
  • auto_LiRPA version: 0.5.0 (auto_LiRPA-0.5.0-py311-none-any.whl)
  • Hardware: x86_64
  • Have you tried to reproduce the problem in a cleanly created conda/virtualenv environment using official installation instructions and the latest code on the main branch?: Yes

Screenshots (debug log and traceback)

INFO     15:18:39     Converting the model...
DEBUG    15:18:40     Graph before ONNX convertion:
DEBUG    15:18:40     graph(%0 : Float(3, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=3]() # /tmp/ipykernel_1909/1641900459.py:20:0
  %2 : int[] = prim::ListConstruct(%1)
  %3 : NoneType = prim::Constant()
  %4 : NoneType = prim::Constant()
  %5 : Device = prim::Constant[value="cpu"]() # /tmp/ipykernel_1909/1641900459.py:20:0
  %6 : bool = prim::Constant[value=0]() # /tmp/ipykernel_1909/1641900459.py:20:0
  %7 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::ones(%2, %3, %4, %5, %6) # /tmp/ipykernel_1909/1641900459.py:20:0
  %8 : Float(2, strides=[1], requires_grad=0, device=cpu) = aten::matmul(%7, %0) # /tmp/ipykernel_1909/1641900459.py:20:0
  return (%8)

DEBUG    15:18:40     trace_graph: graph(%0 : Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %1 : Float(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  1 [ CPUFloatType{3} ]]() # /tmp/ipykernel_1909/1641900459.py:20:0
  %2 : Float(2, strides=[1], requires_grad=0, device=cpu) = onnx::MatMul(%1, %0) # /tmp/ipykernel_1909/1641900459.py:20:0
  return (%2)

DEBUG    15:18:40     ONNX graph:
DEBUG    15:18:40     graph(%0_onnx::MatMul_1641900459_20 : Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %1_onnx::Constant_1641900459_20 : Float(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  1 [ CPUFloatType{3} ]]() # /tmp/ipykernel_1909/1641900459.py:20:0
  %2_onnx::MatMul_1641900459_20 : Float(2, strides=[1], requires_grad=0, device=cpu) = onnx::MatMul(%1_onnx::Constant_1641900459_20, %0_onnx::MatMul_1641900459_20) # /tmp/ipykernel_1909/1641900459.py:20:0
  return (%2_onnx::MatMul_1641900459_20)

INFO     15:18:40     Model converted to support bounds
DEBUG    15:18:40     Compute bounds with IBP
DEBUG    15:18:40     Final node BoundMatMul(/2_onnx::MatMul_1641900459_20)
DEBUG    15:18:40     IBP for BoundMatMul(name=/2_onnx::MatMul_1641900459_20, inputs=[/1_onnx::Constant_1641900459_20, /0_onnx::MatMul_1641900459_20], perturbed=True)
DEBUG    15:18:40     IBP for BoundConstant(name=/1_onnx::Constant_1641900459_20, inputs=[], perturbed=False)
DEBUG    15:18:40     Compute bounds with CROWN
DEBUG    15:18:40     Final node BoundMatMul(/2_onnx::MatMul_1641900459_20)
DEBUG    15:18:40     Getting the bounds of BoundConstant(name=/1_onnx::Constant_1641900459_20, inputs=[], perturbed=False)
DEBUG    15:18:40     Bound backward from BoundMatMul(/2_onnx::MatMul_1641900459_20) to bound BoundMatMul(/2_onnx::MatMul_1641900459_20)
DEBUG    15:18:40       C: shape torch.Size([3, 1, 1]), type <class 'torch.Tensor'>
DEBUG    15:18:40       Bound backward to BoundMatMul(name=/2_onnx::MatMul_1641900459_20, inputs=[/1_onnx::Constant_1641900459_20, /0_onnx::MatMul_1641900459_20], perturbed=True) (out shape torch.Size([2]))
DEBUG    15:18:40         lA type <class 'torch.Tensor'> shape [1, 3]
DEBUG    15:18:40         uA type <class 'torch.Tensor'> shape [1, 3]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 24
     22 bm.forward(torch.ones((3,2)))
     23 bm.compute_bounds(x=(boundTensor((3,2))), method='IBP')
---> 24 bm.compute_bounds(x=(boundTensor((3,2))), method='CROWN')

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/bound_general.py:1402, in BoundedModule.compute_bounds(self, x, aux, C, method, IBP, forward, bound_lower, bound_upper, reuse_ibp, reuse_alpha, return_A, needed_A_dict, final_node_name, average_A, interm_bounds, reference_bounds, intermediate_constr, alpha_idx, aux_reference_bounds, need_A_only, cutter, decision_thresh, update_mask, ibp_nodes, cache_bounds)
   1399     elif bound_upper:
   1400         return ret2  # ret2[0] is None.
-> 1402 return self._compute_bounds_main(C=C,
   1403                                  method=method,
   1404                                  IBP=IBP,
   1405                                  bound_lower=bound_lower,
   1406                                  bound_upper=bound_upper,
   1407                                  reuse_ibp=reuse_ibp,
   1408                                  reuse_alpha=reuse_alpha,
   1409                                  average_A=average_A,
   1410                                  alpha_idx=alpha_idx,
   1411                                  need_A_only=need_A_only,
   1412                                  update_mask=update_mask)

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/bound_general.py:1507, in BoundedModule._compute_bounds_main(self, C, method, IBP, bound_lower, bound_upper, reuse_ibp, reuse_alpha, average_A, alpha_idx, need_A_only, update_mask)
   1502 apply_output_constraints_to = (
   1503     self.bound_opts['optimize_bound_args']['apply_output_constraints_to']
   1504 )
   1505 # This is for the final output bound.
   1506 # No need to pass in intermediate layer beta constraints.
-> 1507 ret = self.backward_general(
   1508     final, C,
   1509     bound_lower=bound_lower, bound_upper=bound_upper,
   1510     average_A=average_A, need_A_only=need_A_only,
   1511     unstable_idx=alpha_idx, update_mask=update_mask,
   1512     apply_output_constraints_to=apply_output_constraints_to)
   1514 if self.bound_opts['compare_crown_with_ibp']:
   1515     new_lower, new_upper = self.compare_with_IBP(final, lower=ret[0], upper=ret[1], C=C)

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/backward_bound.py:282, in backward_general(self, bound_node, C, start_backpropagation_at_node, bound_lower, bound_upper, average_A, need_A_only, unstable_idx, update_mask, verbose, apply_output_constraints_to, initial_As, initial_lb, initial_ub)
    280 else:
    281     start_shape = None
--> 282 A, lower_b, upper_b = l.bound_backward(
    283     lA, uA, *l.inputs,
    284     start_node=bound_node, unstable_idx=unstable_idx,
    285     start_shape=start_shape)
    287 # After propagation through this node, we delete its lA, uA variables.
    288 if bound_node.name != self.final_name:

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/operators/linear.py:891, in BoundMatMul.bound_backward(self, last_lA, last_uA, start_node, *x, **kwargs)
    889 if start_node is not None:
    890     self._start = start_node.name
--> 891 results = super().bound_backward(last_lA, last_uA, *x, **kwargs)
    892 lA_y = results[0][1][0].transpose(-1, -2) if results[0][1][0] is not None else None
    893 uA_y = results[0][1][1].transpose(-1, -2) if results[0][1][1] is not None else None

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/operators/linear.py:361, in BoundLinear.bound_backward(self, last_lA, last_uA, start_node, reduce_bias, *x, **kwargs)
    359 assert not self.use_seperate_weights_for_lower_and_upper_bounds
    360 # Obtain relaxations for matrix multiplication.
--> 361 [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias = self.bound_backward_with_weight(
    362     last_lA, last_uA, input_lb, input_ub, x[0], x[1],
    363     reduce_bias=reduce_bias, **kwargs)
    364 if has_bias:
    365     assert reduce_bias

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/operators/linear.py:457, in BoundLinear.bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y, reduce_bias, **kwargs)
    450 def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub,
    451                                x, y, reduce_bias=True, **kwargs):
    452     # FIXME This is nonlinear. Move to `bivariate.py`.
    453 
    454     # Note: x and y are not tranposed or scaled, and we should avoid using them directly.
    455     # Use input_lb and input_ub instead.
    456     (alpha_l, beta_l, gamma_l,
--> 457      alpha_u, beta_u, gamma_u) = self.mul_helper.get_relaxation(
    458         *self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1]),
    459         self.opt_stage, getattr(self, 'alpha', None),
    460         getattr(self, '_start', None), middle=self.mul_middle)
    461     x_shape = input_lb[0].size()
    462     if reduce_bias:

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/operators/bivariate.py:90, in MulHelper.get_relaxation(x_l, x_u, y_l, y_u, opt_stage, alphas, start_name, middle)
     87     return MulHelper.interpolated_relaxation(
     88         x_l, x_u, y_l, y_u, alphas[ns][:2], alphas[ns][2:4])
     89 else:
---> 90     return MulHelper.interpolated_relaxation(
     91         x_l, x_u, y_l, y_u, middle=middle)

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/operators/bivariate.py:60, in MulHelper.interpolated_relaxation(x_l, x_u, y_l, y_u, r_l, r_u, middle)
     58     gamma_u = (y_l * x_u - y_u * x_l) * 0.5 - y_l * x_u
     59 else:
---> 60     alpha_l, beta_l, gamma_l = y_l, x_l, -y_l * x_l
     61     alpha_u, beta_u, gamma_u = y_u, x_l, -y_u * x_l
     62 return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
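
Possible workaround (a sketch under my assumption that the failure is specific to the BoundMatMul relaxation; I have not verified this against the library internals): for this toy model, torch.matmul(torch.ones(3), x) is just the column-wise sum of x, so the same computation can be written without a matmul node at all. TestSum and bm_sum below are names I made up for illustration; the snippet reuses boundTensor from the repro above.

class TestSum(nn.Module):
    def forward(self, x: torch.Tensor):
        # Same result as torch.matmul(torch.ones(3), x) for x of shape (3, 2):
        # each output entry is the sum of one column of x.
        return x.sum(dim=0)

bm_sum = BoundedModule(TestSum(), torch.empty((3, 2)), verbose=True)
print(bm_sum.compute_bounds(x=(boundTensor((3, 2)),), method='CROWN'))

This of course does not address the underlying issue if the real model needs a genuine constant-times-input matmul.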
