Add support for DebertaV2ForSequenceClassification #746
Conversation
Comparisons from my Mac
kozistr left a comment
Leave a quick review! Great work!
```rust
(Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => {
    tracing::info!("Starting DebertaV2 model on {:?}", device);
    Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
}
```
How about adding CUDA support too?
Good point. I've enabled the feature but I don't have an easily accessible set-up to test this with CUDA. Could you help me with testing to see if everything checks out?
This commit adapts text-embeddings-inference for NVIDIA Jetson Orin (SM87) and L4 GPU (SM89), and integrates valuable community PRs.

Changes:

1. SM87/SM89 CUDA Support
   - Added compute capability 8.7 and 8.9 support
   - Modified Dockerfile-cuda-all for multi-arch builds
   - Updated compute_cap.rs for SM87/89 detection
   - Files: Dockerfile-cuda-all, cuda-all-entrypoint.sh, compute_cap.rs
2. PR huggingface#730: Qwen3 Reranker Support
   - Added classification head for Qwen3 reranking
   - Implemented template formatting system for chat-based reranking
   - Files: models/qwen3.rs, core/templates.rs, core/lib.rs
3. PR huggingface#787: Batch Notification Performance Optimization
   - Implemented AtomicUsize counter for batch processing
   - Reduced unnecessary notify_one() calls
   - Only last request in batch triggers thread notification
   - Files: core/infer.rs, router/http/server.rs, router/grpc/server.rs
4. PR huggingface#753: GeLU Activation Consistency Fix
   - Changed Gelu from approximate (gelu) to exact (gelu_erf)
   - Added NewGelu variant for backward compatibility
   - Files: layers/linear.rs
5. PR huggingface#790: StaticEmbedding Model Support
   - Added support for 0_StaticEmbedding/ directory structure
   - Implemented fallback loading for model weights and tokenizer
   - Default to Mean pooling for StaticEmbedding models
   - Files: models/static_embedding.rs (new), lib.rs, download.rs, router/lib.rs
6. PR huggingface#746: DebertaV2 Sequence Classification Support
   - Complete DebertaV2 model implementation
   - Support for sequence classification tasks (e.g., Llama Prompt Guard)
   - CPU and CUDA device support
   - Files: models/debertav2.rs (new), lib.rs, models/mod.rs

All changes have been tested and compile successfully with: cargo check --all-targets
Compilation verified with CUDA support: cargo install --path router -F candle-cuda
Target Hardware: NVIDIA Jetson Orin AGX (SM87), L4 GPU (SM89)
Date: January 5, 2026
What does this PR do?
This PR adds support for the DebertaV2ForSequenceClassification model.
Shoutout to @kozistr for providing an initial set of reviews.
I have verified that outputs are identical on my Mac. I could use some help testing this on a CUDA machine if anyone can help out!
Fixes #354 #281 #199
PSA: The vast majority of this code has been borrowed from the great work done by @BradyBonnette in huggingface/candle#2743 <3
Before submitting
insta snapshots?
Who can review?
@Narsil @alvarobartt @kozistr