Hi,
I quantized the OpenVLA 7B model using GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128).
However, the quantized model runs about 2× slower than the unquantized bf16 model.
Environment
torch==2.7.1+cu126
torchao==0.14.0
transformers==4.51.3
GPU: A6000
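For completeness, a small sketch (not part of the original run) I can use to confirm the gemlite backend is importable alongside the versions above; the getattr fallback is there because I am not sure gemlite exposes a __version__ attribute:

import torch, torchao, transformers
print("torch:", torch.__version__)
print("torchao:", torchao.__version__)
print("transformers:", transformers.__version__)
try:
    import gemlite  # backend used by GemliteUIntXWeightOnlyConfig
    print("gemlite:", getattr(gemlite, "__version__", "installed"))
except ImportError:
    print("gemlite is not installed")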
Repro script
import torch, gc
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig, quantize_

EXCLUDE_TAGS = ["lm_head", "vision_backbone"]

def filter_fct(module, fqn):
    if not isinstance(module, torch.nn.Linear):
        return False
    return not any([ex in fqn for ex in EXCLUDE_TAGS])

@torch.no_grad()
def benchmark_in_ms(model_forward, warmup, iters, **inputs):
    for _ in range(warmup):
        action = model_forward.predict_action(**inputs, unnorm_key="libero_spatial", do_sample=False)
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(True), torch.cuda.Event(True)
    start.record()
    for _ in range(iters):
        action = model_forward.predict_action(**inputs, unnorm_key="libero_spatial", do_sample=False)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def load_model(model_path, device, dtype=torch.bfloat16):
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path, torch_dtype=dtype, device_map=device, trust_remote_code=True
    )
    model.eval()
    return model, processor

def benchmark_openvla(model_path, image_path, warmup=2, iters=10, device='cuda:0', dtype=torch.bfloat16):
    device_idx = torch.device(device)
    SYSTEM_PROMPT = "A chat between a curious user and an AI assistant. The assistant gives helpful, detailed, polite answers."
    instruction = "pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate"
    prompt = f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}?"
    image = Image.open(image_path)

    # ----- Base model -----
    model_raw, processor = load_model(model_path, device_idx, dtype)
    inputs = processor(prompt, image).to(device_idx, dtype=dtype)
    model_raw = torch.compile(model_raw, mode="reduce-overhead", fullgraph=True)
    raw_time = benchmark_in_ms(model_raw, warmup, iters, **inputs)
    print(f"[RESULT] Raw model avg time: {raw_time:.2f} ms")
    del model_raw; torch.cuda.empty_cache(); gc.collect()

    # ----- Quantized model -----
    model_quant, processor = load_model(model_path, device_idx, dtype)
    config = GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128, mode="weight_only")
    quantize_(model_quant, config, filter_fn=filter_fct)
    inputs = processor(prompt, image).to(device_idx, dtype=dtype)
    model_quant = torch.compile(model_quant, mode="reduce-overhead", fullgraph=True)
    quant_time = benchmark_in_ms(model_quant, warmup, iters, **inputs)
    print(f"[RESULT] Quantized 4-bit model avg time: {quant_time:.2f} ms")

if __name__ == "__main__":
    benchmark_openvla(
        model_path="openvla/openvla-7b-finetuned-libero-spatial",
        image_path="image.jpg",
        warmup=2,
        iters=10,
        device='cuda:0',
    )

Results
Average time for default model: 160 ms
Average time for quantized model: 290 ms
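To double-check the setup, here is a small sketch (I have not run this as part of the numbers above) that prints the weight tensor types of the Linear layers right after the quantize_ call, so I can confirm the gemlite swap actually happened on the intended layers; model_quant and EXCLUDE_TAGS refer to the script above:

# Run immediately after quantize_(model_quant, config, filter_fn=filter_fct),
# before torch.compile, to see which Linear weights were replaced.
for name, module in model_quant.named_modules():
    if isinstance(module, torch.nn.Linear) and not any(ex in name for ex in EXCLUDE_TAGS):
        print(name, type(module.weight).__name__)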
What could be the reason for the slower inference of the quantized model?
Am I doing anything wrong in my setup? Thanks!
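In case it helps narrow this down, below is a minimal standalone sketch to time a single Linear layer before and after applying the same config, to separate the raw gemlite matmul cost from the rest of the predict_action pipeline. The 4096x4096 shape and batch size 1 are placeholders (not taken from OpenVLA), it assumes the gemlite package is installed, and it skips torch.compile, so it only isolates the eager kernel cost:

import torch
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig, quantize_

@torch.no_grad()
def time_ms(fn, warmup=10, iters=100):
    # CUDA-event timing of fn(), averaged over iters after a warmup.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(True), torch.cuda.Event(True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Placeholder projection: 4096x4096 weight with a batch-1, decode-style input.
layer = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)

t_bf16 = time_ms(lambda: layer(x))
print(f"bf16 linear: {t_bf16:.3f} ms")

# Same config as in the repro script, applied to this single layer.
quantize_(layer, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128))
print("weight tensor type after quantize_:", type(layer.weight).__name__)

t_q4 = time_ms(lambda: layer(x))
print(f"gemlite 4-bit linear: {t_q4:.3f} ms")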