Hello! I'm trying to convert a ViT to int8 using MQBench.
When I applied the AdaRound config to a ViT (from the timm library), I got an error related to the use of nn.Parameter.
The error occurs after calibration, in the reconstruction_ptq step.
More specifically, it occurs during the fx.GraphModule conversion step, when extract_subgraph() is called.
ViT usually holds cls_token and pos_embed as nn.Parameter attributes, not as nn.Module submodules.
How can I avoid or solve this?
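A minimal sketch of the situation, using plain torch.fx rather than MQBench itself. The class TinyViTLike and the helper node_ops are hypothetical names made up for this illustration, and the default torch.fx tracer is assumed (MQBench's own tracer may behave differently). It shows that a bare nn.Parameter appears in the traced graph as a get_attr node, whereas an nn.Module layer appears as a call_module node:

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TinyViTLike(nn.Module):
    """Hypothetical stand-in for a ViT: one bare nn.Parameter plus one nn.Module."""

    def __init__(self, num_tokens: int = 4, embed_dim: int = 8):
        super().__init__()
        # Bare parameter, analogous to pos_embed / cls_token in timm's ViT.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
        # Regular submodule, analogous to the layers in ResNet-18.
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        x = x + self.pos_embed
        return self.proj(x)


def node_ops(model: nn.Module):
    """Trace with the default torch.fx tracer and list (opcode, target) pairs."""
    graph_module = fx.symbolic_trace(model)
    return [(node.op, str(node.target)) for node in graph_module.graph.nodes]


ops = node_ops(TinyViTLike())
for op, target in ops:
    print(op, target)
```

In this toy model, pos_embed is reached through a get_attr node and proj through a call_module node, so any code that only expects call_module and call_function nodes when cutting out a subgraph would need to handle get_attr as well.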
- The ResNet-18 model from imagenet_example contains only nn.Module layers, so it did not raise this error.
- I tried wrapping the nn.Parameter in an nn.Module by adding a new nn.Module class to the timm library, like this:
class ParameterWrapper(nn.Module):
    def __init__(self, param):
        super(ParameterWrapper, self).__init__()
        self.param = nn.Parameter(param)

    def forward(self):
        return self.param

# original code: self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.cls_token = ParameterWrapper(torch.zeros(1, 1, embed_dim)) if class_token else None
However, this causes another error after extract_subgraph(), in the subgraph_reconstruction() step.
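One thing worth checking, sketched here with plain torch.fx only: under the default tracer, a user-defined wrapper module is not treated as a leaf, so tracing steps into its forward() and the parameter still ends up as a get_attr node, just under a longer name. WrappedTokenModel is a hypothetical toy model for this illustration. Whether MQBench's tracer treats the wrapper the same way is an assumption that has not been verified here.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class ParameterWrapper(nn.Module):
    """Same wrapper idea as above: expose a parameter through forward()."""

    def __init__(self, param):
        super().__init__()
        self.param = nn.Parameter(param)

    def forward(self):
        return self.param


class WrappedTokenModel(nn.Module):
    """Hypothetical toy model that adds a wrapped token to its input."""

    def __init__(self, embed_dim: int = 8):
        super().__init__()
        self.cls_token = ParameterWrapper(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        # Calling the wrapper instead of reading the parameter directly.
        return x + self.cls_token()


# The default tracer only treats torch.nn built-in layers as leaf modules,
# so it traces through ParameterWrapper.forward().
traced = fx.symbolic_trace(WrappedTokenModel())
wrapped_ops = [(node.op, str(node.target)) for node in traced.graph.nodes]
for op, target in wrapped_ops:
    print(op, target)
```

If the same happens inside MQBench, the wrapper changes only the attribute name (cls_token.param instead of cls_token) and not the node type, which might explain why the failure moves to a later step instead of disappearing.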
Please help me!