
Commit c09c876

Add special case to avoid quantizing conv in Moonshine
- Add a define to prevent quantizing the first conv layers in the Moonshine preprocessor
- Add options to enable rotary positional embeddings in the Transformer Encoder spec.
1 parent: 6373848

3 files changed, 28 insertions(+), 0 deletions(-)


CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@ option(BUILD_TESTS "Compile the tests" OFF)
 option(BUILD_SHARED_LIBS "Build shared libraries" ON)
 option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
 option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
+option(MOONSHINE "Compile with moonshine specializations" OFF)
+
+if (MOONSHINE)
+  add_definitions(-DMOONSHINE)
+endif()
 
 if(ENABLE_PROFILING)
   message(STATUS "Enable profiling support")

python/ctranslate2/specs/transformer_spec.py

Lines changed: 20 additions & 0 deletions
@@ -22,6 +22,11 @@ def __init__(
         relative_attention_bias: bool = False,
         ffn_glu: bool = False,
         rms_norm: bool = False,
+        rotary_dim: Optional[int] = None,
+        rotary_interleave: bool = True,
+        rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+        rotary_scaling_factor: float = 1,
+        rotary_base: float = 10000,
         multi_query_attention: bool = False,
     ):
         """Initializes a Transformer encoder specification.
@@ -66,6 +71,11 @@ def __init__(
                 relative_attention_bias=relative_attention_bias,
                 ffn_glu=ffn_glu,
                 rms_norm=rms_norm,
+                rotary_dim=rotary_dim,
+                rotary_interleave=rotary_interleave,
+                rotary_scaling_type=rotary_scaling_type,
+                rotary_scaling_factor=rotary_scaling_factor,
+                rotary_base=rotary_base,
                 num_heads_kv=1 if multi_query_attention else None,
             )
             for _ in range(num_layers)
@@ -251,6 +261,11 @@ def __init__(
         relative_attention_bias=False,
         ffn_glu=False,
         rms_norm=False,
+        rotary_dim=None,
+        rotary_interleave=True,
+        rotary_scaling_type=None,
+        rotary_scaling_factor=1,
+        rotary_base=10000,
         num_heads_kv=None,
         sliding_window=None,
     ):
@@ -259,6 +274,11 @@ def __init__(
             relative_position=relative_position,
             relative_attention_bias=relative_attention_bias,
             rms_norm=rms_norm,
+            rotary_dim=rotary_dim,
+            rotary_interleave=rotary_interleave,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=rotary_base,
             num_heads_kv=num_heads_kv,
             sliding_window=sliding_window,
         )
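
For illustration, a minimal sketch of how the new arguments could be passed once this change is available. The layer count, head count, and rotary settings below are assumptions chosen for the example, not values taken from this commit; only the keyword names come from the diff above.

    from ctranslate2.specs import transformer_spec

    # Hypothetical encoder spec using rotary positional embeddings.
    # num_layers/num_heads are illustrative; rotary_dim=0 is assumed to mean
    # "apply rotary embeddings to the full head dimension", following the
    # existing decoder spec convention. The scaling options keep the defaults
    # introduced in the diff (rotary_scaling_type=None, rotary_scaling_factor=1).
    encoder_spec = transformer_spec.TransformerEncoderSpec(
        num_layers=6,
        num_heads=8,
        rotary_dim=0,
        rotary_interleave=False,
        rotary_base=10000,
    )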

src/models/model.cc

Lines changed: 3 additions & 0 deletions
@@ -213,6 +213,9 @@ namespace ctranslate2 {
       if (device == Device::CUDA
 #ifdef CT2_WITH_DNNL
           || true
+#endif
+#ifdef MOONSHINE
+          || true
 #endif
           ) {
         variable_weight_dtype = float_dtype;
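
Read together with the commit message, this branch widens an existing fallback: in a MOONSHINE build the affected variables (per the commit message, the Moonshine preprocessor's first conv layers) keep the float dtype on every device, just as they already do when CTranslate2 is compiled with DNNL. A rough Python restatement of that decision, purely as an illustration of the condition in the diff (the actual logic stays in model.cc):

    def use_float_weights(device: str, with_dnnl: bool, with_moonshine: bool) -> bool:
        # Mirrors the condition above: the float dtype is forced on CUDA,
        # or whenever the build defines CT2_WITH_DNNL or MOONSHINE.
        return device == "cuda" or with_dnnl or with_moonshine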
