Recently we finished the last quantization wrapper for the Qwen3VL model - see Qwen3-VL: Implement quantization wrappers.
However, there are some remaining issues that I'd like to discuss below.
1. RoPE position embeddings
The Issue
At the time of writing we are computing RoPE position embeddings (sin and cos) in tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py:
```python
class QuantQwen3VLTextModel(QuantModuleBase):
    ...
    def forward(...):
        ...
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        cos, sin = position_embeddings
        position_embeddings = (
            self._fq(cos, self.obs_cos),
            self._fq(sin, self.obs_sin),
        )
        ...
```
self.rotary_emb is simply a reference to Qwen3VLTextModel.rotary_emb, which is an instance of Qwen3VLTextRotaryEmbedding. This reference is not wrapped in a PTQWrapper:
```python
class QuantQwen3VLTextModel(QuantModuleBase):
    ...
    def __init__(...):
        ...
        self.rotary_emb = fp_model.rotary_emb
        ...
```
Why rotary_emb was not wrapped: I made the following assumption (see PR #535):
> When tico.convert is called, it uses torch.export.export to capture the model as a static computation graph.
> At this point, the dynamic self.rotary_emb(hidden_states, position_ids) call will be executed with concrete inputs,
> resulting in static tensor values for cos and sin.
Later tests showed that the assumption was wrong (see PR #555).
position_embeddings = self.rotary_emb(hidden_states, position_ids) caused the following warning during the conversion to Circle (tico.convert): [QuantCheck] WARNING: 34 nodes without qparam detected (see logs).
The problem is that the computations inside self.rotary_emb(hidden_states, position_ids) are not folded into static tensors and therefore leak into the exported model graph.
Since self.rotary_emb is not wrapped with PTQWrapper, these nodes reach the exported model graph without quantization parameters (as indicated by the warning above).
An unquantized graph cannot be compiled and executed by the NPU.
The Solution
One obvious idea is to precompute RoPE position embeddings before conversion to Circle instead of computing them at inference time.
This can be achieved by simply moving the line position_embeddings = self.rotary_emb(hidden_states, position_ids) from QuantQwen3VLTextModel.forward to QuantQwen3VLTextModel.__init__.
This raises the question of what to do with the hidden_states and position_ids arguments, which are not available in QuantQwen3VLTextModel.__init__.
hidden_states is barely used inside Qwen3VLTextRotaryEmbedding.forward: the method only references hidden_states.device and hidden_states.dtype.
The main concern is position_ids. We need to precompute it. For that we need the following data, as shown in tico/quantization/wrapq/examples/qwen/quantize_model.py (compute_3d_position_ids):
- input_ids - required only to compare against image_token_id to obtain the mask indicating the positions of image tokens in the prompt.
- thw - image shape (num_temporal_patches, num_height_patches, num_width_patches); we already assume that THW is fixed at inference time in PTQConfig - see issue #560.
- spatial_merge_size - constant.
- image_token_id - constant.
Thus, we only need to fix the positions of image tokens in the prompt to be able to precompute RoPE position embeddings.
This means that we actually need to fix the following things at inference time:
- The number (e.g. 1 or 2 images/videos?) and type (image or video?) of visual pieces of data, e.g. 1 image, no videos.
- The shape of the visual data - we already do that by defining grid_thw in PTQConfig (see above).
- The starting position of the visual data within the prompt, e.g. image tokens start at a fixed index 4.
That way visual tokens will always occupy some fixed section of the prompt.
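If all three of these are fixed, the image-token mask (the only input-dependent ingredient of compute_3d_position_ids) can be built without real input_ids. A minimal sketch under these assumptions; build_image_mask is a hypothetical helper name, and the spatial merge size of 2 is illustrative:

```python
import torch

def build_image_mask(seq_len, img_start_idx, grid_thw, spatial_merge_size=2):
    # With a fixed start index and a fixed grid_thw, the number and the
    # positions of image tokens are known ahead of time, so no real
    # input_ids comparison against image_token_id is needed.
    t, h, w = grid_thw
    num_visual_tokens = t * (h // spatial_merge_size) * (w // spatial_merge_size)
    mask = torch.zeros(seq_len, dtype=torch.bool)
    mask[img_start_idx : img_start_idx + num_visual_tokens] = True
    return mask

mask = build_image_mask(seq_len=128, img_start_idx=4, grid_thw=(1, 16, 16))
print(mask.sum().item())  # 64 visual tokens: a 16x16 patch grid merged 2x2
```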
What I suggest is to introduce another model parameter (let's name it, say, img_start_idx) to PTQConfig:
```python
PTQConfig(
    model_args={
        "vision": {
            "grid_thw": grid_thw,
            "img_start_idx": img_start_idx,
        }
    }
)
```
QuantQwen3VLTextModel.__init__ will read this img_start_idx from the config and use it to precompute the RoPE position embeddings.
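A rough sketch of this move, using a toy rotary embedding with the same calling convention (toy_rotary_emb and PrecomputedRoPE are hypothetical stand-ins, not the real Qwen3VLTextRotaryEmbedding):

```python
import torch

def toy_rotary_emb(hidden_states, position_ids, dim=8, base=10000.0):
    # Same interface as the real module: hidden_states is only consulted
    # for device/dtype, position_ids drives the actual computation.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=hidden_states.dtype) / dim))
    freqs = position_ids[..., None].to(hidden_states.dtype) * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

class PrecomputedRoPE:
    def __init__(self, position_ids):
        # hidden_states is only needed for device/dtype, so a dummy suffices.
        dummy = torch.zeros(1)
        self.cos, self.sin = toy_rotary_emb(dummy, position_ids)

    def forward(self):
        # No rotary_emb call at inference time: cos/sin are already plain
        # tensors, so (the assumption goes) torch.export captures them as
        # constants and no unquantized subgraph leaks into the export.
        return self.cos, self.sin

rope = PrecomputedRoPE(torch.arange(16))
print(rope.cos.shape)  # torch.Size([16, 8])
```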
2. Fusion of visual embeddings and text embeddings
The Issue
As mentioned in PR #555, merging visual embeddings (obtained from the vision encoder) with text embeddings is an operation that may cause trouble during conversion to Circle.
If we assume arbitrary positions of visual data in the prompt, we have to use operations that cannot be converted to Circle (e.g. torch.Tensor.masked_scatter).
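For illustration, here is roughly what the dynamic variant would look like (a standalone sketch, not tico code): with arbitrary visual-token positions, the fusion needs a data-dependent masked_scatter, which is exactly the kind of operation that does not convert:

```python
import torch

inputs_embeds = torch.zeros(1, 6, 2)   # batch=1, seq_len=6, hidden=2
visual_embeds = torch.ones(2, 2)       # 2 visual tokens
# The mask positions depend on the runtime input, not on a fixed index.
image_mask = torch.tensor([[False, True, True, False, False, False]])
fused = inputs_embeds.masked_scatter(image_mask.unsqueeze(-1), visual_embeds)
print(fused[0, 1].tolist(), fused[0, 3].tolist())  # [1.0, 1.0] [0.0, 0.0]
```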
The only approach that seems feasible to me is to fix the positions of the visual data in the prompt, as is assumed in tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py (_fuse_text_n_image), where the visual embeddings are simply placed into the first N rows of the input embeddings tensor:
```python
class QuantQwen3VLModel(QuantModuleBase):
    ...
    @staticmethod
    def _fuse_text_n_image(inputs_embeds, visual_embeds):
        num_visual_tokens = visual_embeds.shape[0]
        flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
        flat_inputs[:num_visual_tokens] = visual_embeds
        inputs_embeds = flat_inputs.view_as(inputs_embeds)
        return inputs_embeds
    ...
```
This _fuse_text_n_image function implicitly assumes that visual tokens always start at index 0 in the prompt.
In reality, however, visual tokens will likely start at some positive index. Here's an example of a prompt template in Chat Markup Language:
```
<|im_start|>user\n
<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n
<|im_start|>assistant\n
```
Here image tokens are preceded by 4 other tokens (<|im_start|>, user, \n, <|vision_start|>) and therefore start at index 4 rather than 0.
The Solution
The solution is the same as before: fix the starting position of visual data in the prompt, store it in PTQConfig, and read it in QuantQwen3VLModel._fuse_text_n_image:
```python
def _fuse_text_n_image(inputs_embeds, visual_embeds):
    vision_args = qcfg.get_model_arg("vision", {})
    img_start_idx = vision_args.get("img_start_idx")
    num_visual_tokens = visual_embeds.shape[0]
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    # Visual rows occupy [img_start_idx, img_start_idx + num_visual_tokens).
    flat_inputs[img_start_idx : img_start_idx + num_visual_tokens] = visual_embeds
    inputs_embeds = flat_inputs.view_as(inputs_embeds)
    return inputs_embeds
```
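A self-contained sanity check of offset-based fusion (a standalone re-implementation with img_start_idx passed in directly, since qcfg is not available here; note that the end of the slice must be img_start_idx + num_visual_tokens, not num_visual_tokens alone):

```python
import torch

def fuse_text_n_image(inputs_embeds, visual_embeds, img_start_idx):
    # Place the visual rows at [img_start_idx, img_start_idx + num_visual_tokens).
    num_visual_tokens = visual_embeds.shape[0]
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    flat_inputs[img_start_idx : img_start_idx + num_visual_tokens] = visual_embeds
    return flat_inputs.view_as(inputs_embeds)

text = torch.zeros(1, 10, 4)  # batch=1, seq_len=10, hidden=4
vis = torch.ones(3, 4)        # 3 visual tokens
out = fuse_text_n_image(text, vis, img_start_idx=4)
# Rows 4..6 hold the visual embeddings; rows 3 and 7 stay untouched.
print(out[0, 3].sum().item(), out[0, 4].sum().item(), out[0, 7].sum().item())  # 0.0 4.0 0.0
```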
3. The Runtime
The above reasoning comes from the usual assumption that every single layer in the model must be a) converted to Circle and b) quantized.
But this depends on the runtime logic. I'm not aware of any specification of a runtime that is going to execute the Qwen3VL model on the target device.
It may be that some parts of the model are executed on the CPU, while other parts (e.g. the vision blocks and the decoder layers) will certainly be executed on the NPU.
Thus, if, for example, the RoPE computation and the fusion of image and text embeddings end up being performed on the CPU, the above solutions may become irrelevant.
There's a runtime simulator https://github.com/Samsung/TICO/blob/main/tico/quantization/wrapq/examples/static_llama_layer_runtime.py for Llama.
Is there an implementation or specification of the runtime for Qwen3VL?