[model, ops] feat: add Qwen3 sequence classification model and loss for embedding classification. #322
Conversation
Code Review
This pull request introduces support for embedding classification by adding new data collators and a sequence classification head for the Qwen3 model. The changes are well-structured, but I've identified significant code duplication in the new data collators in veomni/data/data_collator.py. My review includes suggestions to refactor these classes using inheritance to improve maintainability and reduce redundancy. I also found some dead code that should be removed for clarity.
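The inheritance-based refactor suggested here could look roughly like the following sketch. The class names, fields, and padding logic are illustrative assumptions, not the repo's actual `DataCollator` API: shared collation lives in the base class, and the classification variant overrides only the label handling.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class DataCollator:
    """Base collator: pads input_ids; subclasses add their own fields."""

    pad_token_id: int = 0

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self._pad(features)
        batch.update(self._extra_fields(features))
        return batch

    def _pad(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_len = max(len(f["input_ids"]) for f in features)
        return {
            "input_ids": [
                f["input_ids"] + [self.pad_token_id] * (max_len - len(f["input_ids"]))
                for f in features
            ]
        }

    def _extra_fields(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {}


@dataclass
class ClassificationDataCollator(DataCollator):
    """Adds one sequence-level label per example; padding is inherited."""

    def _extra_fields(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {"labels": [f["label"] for f in features]}
```

With this shape, each new collator only overrides the hook it needs, which is the duplication reduction the review asks for.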
veomni/data/data_collator.py
Outdated
@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
Split this out into another MR.
Moved to #376
veomni/models/loader.py
Outdated
arch_name = get_model_arch_from_config(model_config)
model_type = model_config.model_type
if not force_use_huggingface:
rebase this?
Already rebased.
veomni/ops/loss.py
Outdated
    **kwargs,
) -> torch.Tensor:
    # We don't use shift_labels
    assert shift_labels is None
`assert` statements can be skipped (e.g. under `python -O`); do not use them for validation in production code.
resolved.
veomni/ops/loss.py
Outdated
loss = None
logits = None

if labels is None:
Raise an exception if `labels` is None.
A ValueError is now raised when `labels` is None.
veomni/ops/loss.py
Outdated
    return loss, logits


def seqcls_token_loss_function(
This file is no longer there. Can we follow the new pattern defined in https://github.com/ByteDance-Seed/VeOmni/blob/main/veomni/ops/fused_cross_entropy/__init__.py?
Yes, we can. Implemented following the new pattern.
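For illustration, a registry-style loss-function module like the one referenced above might look like this generic sketch. The names (`register_loss`, `get_loss_function`) and signatures are assumptions for the example, not the actual `fused_cross_entropy` API:

```python
import torch
import torch.nn.functional as F

# Simple name -> loss-function registry (illustrative, not the real API).
_LOSS_REGISTRY = {}


def register_loss(name: str):
    def wrap(fn):
        _LOSS_REGISTRY[name] = fn
        return fn
    return wrap


@register_loss("seqcls_cross_entropy")
def seqcls_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    # Flatten (batch, seq, classes) -> (tokens, classes); positions whose
    # label equals ignore_index do not contribute to the mean.
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=ignore_index
    )


def get_loss_function(name: str):
    return _LOSS_REGISTRY[name]
```

The advantage of this pattern is that models can select a loss by name at construction time instead of importing a specific function from a module that may move.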
)

hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
Remove this; it's no longer needed. We can just use the logits from the loss_function.
removed.
    **kwargs: Unpack[FlashAttentionKwargs],
) -> SequenceClassifierOutputWithPast:
    transformer: Qwen3Model = getattr(self, self.base_model_prefix)
    transformer_outputs: BaseModelOutputWithPast = transformer(
self.model(...)?
It has been revised to a simpler version as suggested.
veomni/ops/loss.py
Outdated
    labels: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
Remove `shift_labels`; let it be absorbed by `**kwargs`.
removed.
tests/data/test_seqcls_loss.py
Outdated
loss, logits = m.seqcls_token_loss_function(hidden_states, weight, labels=labels, ignore_index=-100)

assert loss is not None
The unit tests for the loss function have been adjusted according to this document.
veomni/ops/loss.py
Outdated
| ignore_index: int = -100, | ||
| shift_labels: Optional[torch.Tensor] = None, | ||
| **kwargs, | ||
| ) -> torch.Tensor: |
Add a docstring.
added.
        weights=self.score.weight,
        **kwargs,
    )
else:
What about the inference case, where no labels are provided?
Added logic to calculate logits when no labels are provided, for compatibility with inference tasks.
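The labels-optional behavior described above can be sketched as follows. The class and attribute names (`SeqClsHead`, `score`) are hypothetical, loosely following the snippets in this thread; the real model wires the loss through its `loss_function` instead of calling `F.cross_entropy` directly.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqClsHead(nn.Module):
    """Minimal sequence-classification head with a labels-optional forward."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(
        self, hidden_states: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        logits = self.score(hidden_states)
        if labels is None:
            # Inference path: no labels, so return logits only (loss is None).
            return None, logits
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
        )
        return loss, logits
```

Callers can then dispatch on whether `loss` is None, which keeps the same forward usable for both training and inference.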
tests/ops/test_seqcls_loss.py
Outdated
@@ -0,0 +1,210 @@
import math
TBH, I don't understand what this test does.
It verifies that the sequence-classification loss dispatches to the right function, handles masking and sequence parallelism (SP) correctly, and produces the exact cross-entropy value expected from a manual calculation. For now this only covers manually constructed cases that check the loss value; in the future we could add a real end-to-end test inside the trainer.
piyifan123
left a comment
@Coach257 do you want to take another look?
hidden_states = kwargs.pop("hidden_states", None)
weights = kwargs.pop("weights", None)

assert hidden_states is not None or logits is not None, "hidden_states or logits must be provided."
Replaced asserts with explicit ValueError.
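The explicit validation could look like this sketch (hypothetical helper name; it mirrors the `assert` shown in the snippet above but raises a ValueError instead, so the check survives `python -O`):

```python
import torch


def resolve_logits(logits=None, hidden_states=None, weights=None):
    # Explicit ValueError instead of `assert`, so the check cannot be
    # stripped by running Python with optimizations enabled.
    if logits is None and hidden_states is None:
        raise ValueError("hidden_states or logits must be provided.")
    if logits is None:
        # Project hidden states with the classification-head weight matrix
        # (weights has shape [num_classes, hidden]).
        logits = hidden_states @ weights.T
    return logits
```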
tests/ops/test_seqcls_loss.py
Outdated
| """ | ||
| device = torch.device("cuda") | ||
| monkeypatch.setattr(m, "get_parallel_state", lambda: _FakePS(sp_enabled=False)) | ||
| ignore = -100 |
Import the veomni constant IGNORE_INDEX instead of hard-coding -100.
Replaced.
tests/ops/test_seqcls_loss.py
Outdated
    hidden_states=hidden_states,
    weights=weights,
)
expected = math.log(float(3))
Can we just write the one-line torch command that does the matmul + softmax + cross-entropy?
Updated.
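The reviewer's suggested one-liner amounts to something like the following (illustrative shapes; `F.cross_entropy` fuses the log-softmax and NLL steps, so no separate softmax call is needed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(4, 5)  # (tokens, hidden) -- illustrative shapes
weights = torch.randn(3, 5)        # (num_classes, hidden)
labels = torch.tensor([0, 2, 1, 2])

# matmul + softmax + cross-entropy in one expression:
expected = F.cross_entropy(hidden_states @ weights.T, labels)
```

Computing the expected value this way keeps the test self-documenting instead of relying on a hand-derived constant like `math.log(3)`.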
tests/ops/test_seqcls_loss.py
Outdated
logits = torch.zeros((1, 2, 3), device=device)
labels = torch.tensor([[ignore, 1]], device=device)
hidden_states = torch.zeros((1, 2, 5), device=device)
weights = torch.zeros((3, 5), device=device)
Can we make it a matrix?
Yes, sure. Updated.
tests/ops/test_seqcls_loss.py
Outdated
@@ -0,0 +1,210 @@
import math
Check whether we need to add the tests/ops folder to https://github.com/ByteDance-Seed/VeOmni/blob/main/.github/workflows/gpu_unit_tests.yml.
Yes, we need it. I added the tests/ops directory.
* [docs] feat: add async doc in ulysses.md (#388)
* [model] fix: Fused operator fix for qwen3vl (#378)
* [perf, dist] feat: add zero2 in fsdp1 and use_orig_params configurable (#382)
* [data,ci,docs] feat: add torchcodec-based video processing with ffmpeg support and comprehensive testing (#221)
* [data,ci] test: enhance video_utils test suite with robust validation and benchmarks (#375)
* [data, model] feat: support Qwen3-VL textual token-based time encoding (#386)
* [config] feat: add MFU calculation for qwen3_vl_moe (#385)
* [docs] fix: Optimize document links in Markdown rendering (#380)
* [model, ops] feat: add Qwen3 sequence classification model and loss for embedding classification. (#322)
* [dist, data] fix: init parallel state in data collator post init to avoid worker processing getting single process state (#383)

See merge request: !78
feat: Add a data collator to support embedding classification.