Quantized OpenVLA 7B (4-bit) runs slower than default model #3236

@arunmadhusud

Description

Hi,

I quantized the OpenVLA 7B model using GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128).
However, the quantized model runs ~2× slower than the unquantized bf16 baseline.

Environment

torch==2.7.1+cu126
torchao==0.14.0
transformers==4.51.3
GPU: A6000

Repro script

import torch, gc
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig, quantize_

EXCLUDE_TAGS = ["lm_head", "vision_backbone"]

def filter_fct(module, fqn):
    if not isinstance(module, torch.nn.Linear):
        return False
    return not any([ex in fqn for ex in EXCLUDE_TAGS])

@torch.no_grad()
def benchmark_in_ms(model, warmup, iters, **inputs):
    # Warmup iterations (the first call also triggers torch.compile compilation).
    for _ in range(warmup):
        model.predict_action(**inputs, unnorm_key="libero_spatial", do_sample=False)
    torch.cuda.synchronize()
    # Time `iters` calls with CUDA events and return the average in milliseconds.
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        model.predict_action(**inputs, unnorm_key="libero_spatial", do_sample=False)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def load_model(model_path, device, dtype=torch.bfloat16):
    # Load the OpenVLA processor and model in the requested dtype on the target device.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path, torch_dtype=dtype, device_map=device, trust_remote_code=True
    )
    model.eval()
    return model, processor

def benchmark_openvla(model_path, image_path, warmup=2, iters=10, device='cuda:0', dtype=torch.bfloat16):
    device_idx = torch.device(device)
    SYSTEM_PROMPT = "A chat between a curious user and an AI assistant. The assistant gives helpful, detailed, polite answers."
    instruction = "pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate"
    prompt = f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}?"
    image = Image.open(image_path)

    # ----- Base model -----
    model_raw, processor = load_model(model_path, device_idx, dtype)
    inputs = processor(prompt, image).to(device_idx, dtype=dtype)
    model_raw = torch.compile(model_raw, mode="reduce-overhead", fullgraph=True)
    raw_time = benchmark_in_ms(model_raw, warmup, iters, **inputs)
    print(f"[RESULT] Raw model avg time: {raw_time:.2f} ms")

    del model_raw; torch.cuda.empty_cache(); gc.collect()

    # ----- Quantized model -----
    model_quant, processor = load_model(model_path, device_idx, dtype)
    config = GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128, mode="weight_only")
    quantize_(model_quant, config, filter_fn=filter_fct)
    inputs = processor(prompt, image).to(device_idx, dtype=dtype)
    model_quant = torch.compile(model_quant, mode="reduce-overhead", fullgraph=True)
    quant_time = benchmark_in_ms(model_quant, warmup, iters, **inputs)
    print(f"[RESULT] Quantized 4-bit model avg time: {quant_time:.2f} ms")

if __name__ == "__main__":
    benchmark_openvla(
        model_path="openvla/openvla-7b-finetuned-libero-spatial",
        image_path="image.jpg",
        warmup=2,
        iters=10,
        device='cuda:0',
    )
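
One thing I am not sure about: with mode="reduce-overhead", the first calls also pay compilation and CUDA graph capture, so warmup=2 may be too few iterations to reach steady state. As a sanity check (this uses the standard PyTorch logging API; the placement before the benchmark is just illustrative), recompile logging can confirm that nothing recompiles inside the timed loop:

import torch._logging

# Log every dynamo recompilation (equivalent to TORCH_LOGS="recompiles").
# If any of these fire inside the timed loop, the measurement includes
# compilation time rather than steady-state inference.
torch._logging.set_logs(recompiles=True)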

Results

Average time for default model: 160 ms
Average time for quantized model: 290 ms

What could explain the slower inference of the quantized model?
Am I doing anything wrong in my setup? Thanks!
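
For reference, here is a minimal sketch that isolates the kernel itself: it times one bf16 nn.Linear against a gemlite 4-bit quantized copy at batch size 1, roughly the decode shape. The 4096×11008 shape is my assumption for a Llama-7B MLP projection, and it assumes the gemlite package is installed (as required by this config). If the quantized matmul alone is already slower here, the slowdown is kernel-level rather than something in my pipeline:

import torch
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig, quantize_

def time_ms(fn, warmup=10, iters=100):
    # CUDA-event timing of fn(), averaged over `iters` calls.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Shape is an assumption (roughly a Llama-7B MLP projection); batch size 1 ~ decode.
lin_bf16 = torch.nn.Linear(4096, 11008, bias=False, device="cuda", dtype=torch.bfloat16)
lin_q = torch.nn.Linear(4096, 11008, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(lin_q, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128))

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
print(f"bf16 linear : {time_ms(lambda: lin_bf16(x)):.4f} ms")
print(f"int4 gemlite: {time_ms(lambda: lin_q(x)):.4f} ms")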
