Skip to content

Commit 034035e

Browse files
Add DeepseekV3 Export (#47)
1 parent 2102d54 commit 034035e

File tree

8 files changed: 29 additions and 1 deletion

docs/source/onnx/overview.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
4141
- Deberta
4242
- Deberta-v2
4343
- Decision Transformer
44+
- DeepSeek-V3
4445
- Deit
4546
- Detr
4647
- DINOv2

optimum/exporters/onnx/model_configs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
ASTDummyAudioInputGenerator,
5656
BartDummyTextInputGenerator,
5757
BloomDummyPastKeyValuesGenerator,
58+
DeepSeekV3DummyPastKeyValuesGenerator,
5859
Dinov2DummyInputGenerator,
5960
DummyCodegenDecoderTextInputGenerator,
6061
DummyDecisionTransformerInputGenerator,
@@ -441,6 +442,13 @@ class ArceeOnnxConfig(LlamaOnnxConfig):
441442
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
442443

443444

445+
@register_tasks_manager_onnx("deepseek_v3", *COMMON_TEXT_GENERATION_TASKS)
446+
class DeepSeekV3OnnxConfig(LlamaOnnxConfig):
447+
MIN_TRANSFORMERS_VERSION = version.parse("4.51.0")
448+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeepSeekV3DummyPastKeyValuesGenerator)
449+
DUMMY_PKV_GENERATOR_CLASS = DeepSeekV3DummyPastKeyValuesGenerator
450+
451+
444452
@register_tasks_manager_onnx("cohere", *COMMON_TEXT_GENERATION_TASKS)
445453
class CohereOnnxConfig(LlamaOnnxConfig):
446454
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")

optimum/exporters/onnx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
MODEL_TYPES_REQUIRING_POSITION_IDS = {
7272
"arcee",
7373
"codegen",
74+
"deepseek_v3",
7475
"cohere",
7576
"falcon",
7677
"gemma",

optimum/onnxruntime/modeling_decoder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,16 @@ def __init__(
208208
self.embed_size_per_head = self.config.head_dim
209209
elif self.config.model_type == "gpt_bigcode":
210210
self.embed_size_per_head = self.config.hidden_size // self.config.num_attention_heads * 2
211+
elif self.config.model_type == "deepseek_v3":
212+
# For deepseek_v3, keys and values have different head dimensions
213+
self.qk_head_dim = self.config.qk_rope_head_dim + self.config.qk_nope_head_dim
214+
self.v_head_dim = self.config.v_head_dim
211215
else:
212216
self.embed_size_per_head = self.config.hidden_size // self.config.num_attention_heads
213217

214218
if self.config.model_type in {
215219
"arcee",
220+
"deepseek_v3",
216221
"cohere",
217222
"gemma",
218223
"helium",
@@ -345,6 +350,10 @@ def forward(
345350
v_shape = (batch_size * self.num_key_value_heads, 0, self.embed_size_per_head)
346351
elif self.config.model_type == "gpt_bigcode" and self.config.multi_query:
347352
k_shape = v_shape = (batch_size, 0, self.embed_size_per_head)
353+
elif self.config.model_type == "deepseek_v3":
354+
# For deepseek_v3, keys and values have different head dimensions
355+
k_shape = (batch_size, self.num_key_value_heads, 0, self.qk_head_dim)
356+
v_shape = (batch_size, self.num_key_value_heads, 0, self.v_head_dim)
348357
else:
349358
k_shape = v_shape = (batch_size, self.num_key_value_heads, 0, self.embed_size_per_head)
350359
k_tensor = torch.zeros(k_shape, dtype=self.dtype, device=self.device)
@@ -375,6 +384,10 @@ def forward(
375384
elif self.config.model_type == "gpt_bigcode" and self.config.multi_query:
376385
embed_size_per_head = past_key_values[0].shape[-1]
377386
k_shape = v_shape = (batch_size, pkv_seq_len + seq_len, embed_size_per_head)
387+
elif self.config.model_type == "deepseek_v3":
388+
# For deepseek_v3, keys and values have different head dimensions
389+
k_shape = (batch_size, self.num_key_value_heads, pkv_seq_len + seq_len, self.qk_head_dim)
390+
v_shape = (batch_size, self.num_key_value_heads, pkv_seq_len + seq_len, self.v_head_dim)
378391
else:
379392
embed_size_per_head = past_key_values[0].shape[-1]
380393
k_shape = v_shape = (batch_size, self.num_key_value_heads, pkv_seq_len + seq_len, embed_size_per_head)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ classifiers = [
2929
"Topic :: Scientific/Engineering :: Artificial Intelligence",
3030
]
3131
dependencies = [
32-
"optimum @ git+https://github.com/huggingface/optimum",
32+
"optimum @ git+https://github.com/huggingface/optimum@add-deepseekv3-dummypastkeyvaluesgenerator",
3333
"transformers>=4.36,<4.54.0",
3434
"onnx",
3535
]

tests/exporters/onnx/utils_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
7777
"decision_transformer": "edbeeching/decision-transformer-gym-hopper-medium",
7878
"deit": "hf-internal-testing/tiny-random-DeiTModel",
79+
"deepseek_v3": "hf-internal-testing/tiny-random-DeepseekV3ForCausalLM",
7980
"dinov2": "hf-internal-testing/tiny-random-Dinov2Model",
8081
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
8182
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",

tests/onnxruntime/test_decoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
ArceeOnnxConfig,
3232
BloomOnnxConfig,
3333
CohereOnnxConfig,
34+
DeepSeekV3OnnxConfig,
3435
GemmaOnnxConfig,
3536
GraniteOnnxConfig,
3637
HeliumOnnxConfig,
@@ -122,6 +123,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
122123
SUPPORTED_ARCHITECTURES.append("internlm2")
123124
if is_transformers_version(">=", str(SmolLM3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
124125
SUPPORTED_ARCHITECTURES.append("smollm3")
126+
if is_transformers_version(">=", str(DeepSeekV3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
127+
SUPPORTED_ARCHITECTURES.append("deepseek_v3")
125128
if is_transformers_version(">=", str(StableLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
126129
SUPPORTED_ARCHITECTURES.append("stablelm")
127130

tests/onnxruntime/testing_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"data2vec-audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
5151
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
5252
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
53+
"deepseek_v3": "hf-internal-testing/tiny-random-DeepseekV3ForCausalLM",
5354
"deit": "hf-internal-testing/tiny-random-DeiTModel",
5455
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
5556
"detr": "hf-internal-testing/tiny-random-detr",

Comments (0): this commit has no comments.