Skip to content

Commit 8d1f4ea

Browse files
committed
Update
1 parent c82213d commit 8d1f4ea

File tree

4 files changed

+46
-63
lines changed

4 files changed

+46
-63
lines changed

dlib/dnn/transformer.h

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,29 @@ namespace dlib
3232
namespace canonical_transformer
3333
{
3434

35-
template <long seq_len, long d_model, long num_heads, typename SUBNET>
36-
using query = reshape_to<num_heads, seq_len, d_model / num_heads,
35+
template <long d_model, long num_heads, typename SUBNET>
36+
using query = reshape_to<num_heads, -1, d_model / num_heads,
3737
linear_no_bias<d_model, SUBNET>>;
3838

39-
template <long seq_len, long d_model, long num_heads, typename SUBNET>
40-
using key = reshape_to<num_heads, seq_len, d_model / num_heads,
39+
template <long d_model, long num_heads, typename SUBNET>
40+
using key = reshape_to<num_heads, -1, d_model / num_heads,
4141
linear_no_bias<d_model, SUBNET>>;
4242

43-
template <long seq_len, long d_model, long num_heads, typename SUBNET>
44-
using value = reshape_to<num_heads, seq_len, d_model / num_heads,
43+
template <long d_model, long num_heads, typename SUBNET>
44+
using value = reshape_to<num_heads, -1, d_model / num_heads,
4545
linear_no_bias<d_model, SUBNET>>;
4646

4747
template <template <typename> class ACT, template <typename> class DO,
48-
long seq_len, long d_model, long num_heads, typename SUBNET>
48+
long d_model, long num_heads, typename SUBNET>
4949
using multihead_attention =
50-
DO<linear_no_bias<d_model, reshape_to<1, seq_len, d_model,
50+
DO<linear_no_bias<d_model, reshape_to<1, -1, d_model,
5151
multm_prev3<softmaxm<tril_mask<
5252
scale_weights<d_model / num_heads,
5353
multm_prev4<
54-
rope<query<seq_len, d_model, num_heads, skip1<
54+
rope<query<d_model, num_heads, skip1<
5555
tag4<transpose<
56-
rope<key<seq_len, d_model, num_heads, skip2<
57-
tag3<value<seq_len, d_model, num_heads,
56+
rope<key<d_model, num_heads, skip2<
57+
tag3<value<d_model, num_heads,
5858
tag2<SUBNET>>>>>>>>>>>>>>>>>>>;
5959

6060
template <template <typename> class ACT, template <typename> class DO,
@@ -68,29 +68,29 @@ namespace dlib
6868
tag7<silu<linear<(d_model * 2) / 7, tag6<SUBNET>>>>>>>>>;
6969

7070
template <template <typename> class ACT, template <typename> class DO,
71-
long seq_len, long d_model, long num_heads, typename SUBNET>
71+
long d_model, long num_heads, typename SUBNET>
7272
using transformer_block =
7373
add_prev5<std_ffn<ACT, DO, d_model, rms_norm<tag5<
74-
add_prev1<multihead_attention<ACT, DO, seq_len, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>>>>;
74+
add_prev1<multihead_attention<ACT, DO, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>>>>;
7575

7676
template<long remaining_layers, template <typename> class ACT, template <typename> class DO,
77-
long seq_len, long d_model, long num_heads, typename SUBNET, typename enabled = void>
77+
long d_model, long num_heads, typename SUBNET, typename enabled = void>
7878
struct transformer_stack_impl
7979
{
80-
using type = transformer_block<ACT, DO, seq_len, d_model, num_heads,
81-
typename transformer_stack_impl<remaining_layers - 1, ACT, DO, seq_len, d_model, num_heads, SUBNET>::type>;
80+
using type = transformer_block<ACT, DO, d_model, num_heads,
81+
typename transformer_stack_impl<remaining_layers - 1, ACT, DO, d_model, num_heads, SUBNET>::type>;
8282
};
8383

8484
template<template <typename> class ACT, template <typename> class DO,
85-
long seq_len, long d_model, long num_heads, typename SUBNET>
86-
struct transformer_stack_impl<0, ACT, DO, seq_len, d_model, num_heads, SUBNET, void>
85+
long d_model, long num_heads, typename SUBNET>
86+
struct transformer_stack_impl<0, ACT, DO, d_model, num_heads, SUBNET, void>
8787
{
8888
using type = tag10<SUBNET>;
8989
};
9090

9191
template<long num_layers, template <typename> class ACT, template <typename> class DO,
92-
long seq_len, long d_model, long num_heads, typename SUBNET>
93-
using transformer_stack = typename transformer_stack_impl<num_layers, ACT, DO, seq_len, d_model, num_heads, SUBNET>::type;
92+
long d_model, long num_heads, typename SUBNET>
93+
using transformer_stack = typename transformer_stack_impl<num_layers, ACT, DO, d_model, num_heads, SUBNET>::type;
9494

9595
} // namespace std_transformer
9696

@@ -179,7 +179,6 @@ namespace dlib
179179
using l_net_type = L_NET;
180180

181181
explicit hrm_() :
182-
seq_len(0),
183182
hidden_dim(0),
184183
learning_rate_multiplier(1.0)
185184
{
@@ -190,7 +189,6 @@ namespace dlib
190189
l_net(other.l_net),
191190
z_h_init(other.z_h_init),
192191
z_l_init(other.z_l_init),
193-
seq_len(other.seq_len),
194192
hidden_dim(other.hidden_dim),
195193
learning_rate_multiplier(other.learning_rate_multiplier)
196194
{
@@ -203,7 +201,6 @@ namespace dlib
203201
l_net = other.l_net;
204202
z_h_init = other.z_h_init;
205203
z_l_init = other.z_l_init;
206-
seq_len = other.seq_len;
207204
hidden_dim = other.hidden_dim;
208205
learning_rate_multiplier = other.learning_rate_multiplier;
209206
}
@@ -215,8 +212,7 @@ namespace dlib
215212
{
216213
const tensor& input = sub.get_output();
217214

218-
// Store dimensions for initialization
219-
seq_len = input.nr();
215+
// Store dimension for initialization
220216
hidden_dim = input.nc();
221217

222218
// Initialize hidden states with truncated normal (std=1, trunc=2)
@@ -229,6 +225,7 @@ namespace dlib
229225
const tensor& x = sub.get_output();
230226
const long batch_size = x.num_samples();
231227
const long k = x.k();
228+
const long seq_len = x.nr();
232229

233230
// Allocate working tensors with proper batch size
234231
z_h_current.copy_size(x);
@@ -356,7 +353,6 @@ namespace dlib
356353
serialize(item.l_net, out);
357354
serialize(item.z_h_init, out);
358355
serialize(item.z_l_init, out);
359-
serialize(item.seq_len, out);
360356
serialize(item.hidden_dim, out);
361357
serialize(item.learning_rate_multiplier, out);
362358
}
@@ -372,7 +368,6 @@ namespace dlib
372368
deserialize(item.l_net, in);
373369
deserialize(item.z_h_init, in);
374370
deserialize(item.z_l_init, in);
375-
deserialize(item.seq_len, in);
376371
deserialize(item.hidden_dim, in);
377372
deserialize(item.learning_rate_multiplier, in);
378373
}
@@ -449,7 +444,6 @@ namespace dlib
449444
resizable_tensor z_l_init;
450445

451446
// Dimensions and learning rate
452-
long seq_len;
453447
long hidden_dim;
454448
double learning_rate_multiplier;
455449

@@ -473,7 +467,7 @@ namespace dlib
473467

474468
// Gate network: produces raw logits for expert selection
475469
template <long num_experts, template <typename> class DO, typename SUBNET>
476-
using gate = fc<num_experts, DO<leaky_relu<fc<num_experts * 8, SUBNET>>>>;
470+
using gate = fc<num_experts, DO<leaky_relu<fc<num_experts * 8, avg_pool_everything<SUBNET>>>>>;
477471

478472
struct training_mode_tag {};
479473
struct inference_mode_tag {};

examples/slm_advanced_train_ex.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ namespace dlib
6666
* @param num_layers Number of transformer layers
6767
* @param num_heads Number of attention heads
6868
* @param embedding_dim Dimension of token embeddings
69-
* @param max_seq_len Maximum sequence length
7069
* @param activation_func Activation function type
7170
* @param dropout_policy Dropout regularization policy
7271
*/
@@ -75,7 +74,6 @@ namespace dlib
7574
long num_layers = 6,
7675
long num_heads = 8,
7776
long embedding_dim = 512,
78-
long max_seq_len = 300,
7977
template <typename> class activation_func = gelu,
8078
template <typename> class dropout_policy = dropout_10
8179
>
@@ -85,7 +83,6 @@ namespace dlib
8583
static constexpr long NUM_LAYERS = num_layers;
8684
static constexpr long NUM_HEADS = num_heads;
8785
static constexpr long EMBEDDING_DIM = embedding_dim;
88-
static constexpr long MAX_SEQ_LEN = max_seq_len;
8986

9087
// Compile-time validation of model configuration
9188
struct validation {
@@ -98,10 +95,10 @@ namespace dlib
9895
template<bool is_training>
9996
using network_type = std::conditional_t<is_training,
10097
classification_head<VOCAB_SIZE, EMBEDDING_DIM,
101-
transformer_stack<NUM_LAYERS, activation_func, dropout_policy, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
98+
transformer_stack<NUM_LAYERS, activation_func, dropout_policy, EMBEDDING_DIM, NUM_HEADS,
10299
embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
103100
classification_head<VOCAB_SIZE, EMBEDDING_DIM,
104-
transformer_stack<NUM_LAYERS, activation_func, multiply, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
101+
transformer_stack<NUM_LAYERS, activation_func, multiply, EMBEDDING_DIM, NUM_HEADS,
105102
embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>>;
106103

107104
struct model_info {
@@ -111,8 +108,7 @@ namespace dlib
111108
<< "- vocabulary size: " << VOCAB_SIZE << "\n"
112109
<< "- layers: " << NUM_LAYERS << "\n"
113110
<< "- attention heads: " << NUM_HEADS << "\n"
114-
<< "- embedding dimension: " << EMBEDDING_DIM << "\n"
115-
<< "- sequence length: " << MAX_SEQ_LEN;
111+
<< "- embedding dimension: " << EMBEDDING_DIM;
116112
return ss.str();
117113
}
118114
};
@@ -309,8 +305,7 @@ int main(int argc, char** argv)
309305
num_tokens, // vocab_size
310306
num_layers, // number of layers
311307
num_heads, // number of attention heads
312-
embedding_dim, // embedding dimension
313-
max_seq_len // maximum sequence length
308+
embedding_dim // embedding dimension
314309
>;
315310

316311
// Load internal dataset

examples/slm_chatbot_ex.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ namespace dlib
5454

5555
// Complete transformer block with MoE-based feed-forward layer
5656
template <template <typename> class ACT, template <typename> class DO,
57-
long seq_len, long d_model, long num_heads, typename MODE, typename SUBNET>
57+
long d_model, long num_heads, typename MODE, typename SUBNET>
5858
using trans_moe_block =
5959
moe_ffn<expert_net_type<DO, d_model>, 4, 0, MODE, DO,
60-
add_prev1<multihead_attention<ACT, DO, seq_len, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>;
60+
add_prev1<multihead_attention<ACT, DO, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>;
6161

6262
// Classification head for next-token prediction in conversational context
6363
template <long num_logits, typename SUBNET>
@@ -66,7 +66,6 @@ namespace dlib
6666
// Chatbot model configuration
6767
template<
6868
long vocab_size = 2000,
69-
long max_seq_len = 128,
7069
long num_layers = 3,
7170
long num_heads = 6,
7271
long embedding_dim = 192,
@@ -75,7 +74,6 @@ namespace dlib
7574
>
7675
struct chatbot_config {
7776
static constexpr long VOCAB_SIZE = vocab_size;
78-
static constexpr long MAX_SEQ_LEN = max_seq_len;
7977
static constexpr long NUM_LAYERS = num_layers;
8078
static constexpr long NUM_HEADS = num_heads;
8179
static constexpr long EMBEDDING_DIM = embedding_dim;
@@ -90,13 +88,13 @@ namespace dlib
9088
// Network component definitions for training (with dropout)
9189
template <typename SUBNET>
9290
using t_transformer_block =
93-
trans_moe_block<activation_func, dropout_policy, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
91+
trans_moe_block<activation_func, dropout_policy, EMBEDDING_DIM, NUM_HEADS,
9492
training_mode_tag, SUBNET>;
9593

9694
// Network component definitions for inference (using multiply)
9795
template <typename SUBNET>
9896
using i_transformer_block =
99-
trans_moe_block<activation_func, multiply, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
97+
trans_moe_block<activation_func, multiply, EMBEDDING_DIM, NUM_HEADS,
10098
inference_mode_tag, SUBNET>;
10199

102100
// Complete network type selector based on training/inference mode
@@ -117,7 +115,6 @@ namespace dlib
117115
<< "- Layers: " << NUM_LAYERS << " transformer layers with MoE\n"
118116
<< "- Attention heads: " << NUM_HEADS << "\n"
119117
<< "- Embedding dimension: " << EMBEDDING_DIM << "\n"
120-
<< "- Context window: " << MAX_SEQ_LEN << " tokens\n"
121118
<< "- Experts per layer: 4 (auto top-n selection)";
122119
return ss.str();
123120
}
@@ -246,7 +243,7 @@ int main(int argc, char** argv)
246243
// Configuration parameters
247244
const long vocab_size = 3500;
248245
const long max_seq_len = 128;
249-
using config = chatbot_config<vocab_size, max_seq_len>;
246+
using config = chatbot_config<vocab_size>;
250247
using train_net = config::network_type<true>;
251248
using infer_net = config::network_type<false>;
252249
cout << config::model_info::describe() << "\n\n";

examples/slm_mixture_of_experts_ex.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,16 @@ namespace dlib
5757
Architecture:
5858
1. Multi-head self-attention (from canonical_transformer)
5959
2. MoE feed-forward layer with multiple expert networks
60-
60+
f
6161
This replaces the standard transformer feed-forward layer with a
6262
mixture-of-experts that can specialize different experts for different
6363
types of patterns in the input.
6464
!*/
6565
template <template <typename> class ACT, template <typename> class DO,
66-
long seq_len, long d_model, long num_heads, typename MODE, typename SUBNET>
66+
long d_model, long num_heads, typename MODE, typename SUBNET>
6767
using trans_moe_block =
6868
moe_ffn<expert_net_type<DO, d_model>, 4, 0, MODE, DO,
69-
add_prev1<multihead_attention<ACT, DO, seq_len, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>;
69+
add_prev1<multihead_attention<ACT, DO, d_model, num_heads, rms_norm<tag1<SUBNET>>>>>;
7070

7171
/*!
7272
Classification head for next-token prediction.
@@ -80,7 +80,6 @@ namespace dlib
8080
long num_layers = 6,
8181
long num_heads = 8,
8282
long embedding_dim = 512,
83-
long max_seq_len = 300,
8483
template <typename> class activation_func = gelu,
8584
template <typename> class dropout_policy = dropout_10
8685
>
@@ -89,7 +88,6 @@ namespace dlib
8988
static constexpr long NUM_LAYERS = num_layers;
9089
static constexpr long NUM_HEADS = num_heads;
9190
static constexpr long EMBEDDING_DIM = embedding_dim;
92-
static constexpr long MAX_SEQ_LEN = max_seq_len;
9391

9492
struct validation {
9593
static_assert(VOCAB_SIZE > 0, "Vocabulary size must be positive");
@@ -101,13 +99,13 @@ namespace dlib
10199
// Network component definitions for training (with dropout)
102100
template <typename SUBNET>
103101
using t_transformer_block =
104-
trans_moe_block<activation_func, dropout_policy, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
102+
trans_moe_block<activation_func, dropout_policy, EMBEDDING_DIM, NUM_HEADS,
105103
training_mode_tag, SUBNET>;
106104

107105
// Network component definitions for inference (using multiply)
108106
template <typename SUBNET>
109107
using i_transformer_block =
110-
trans_moe_block<activation_func, multiply, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS,
108+
trans_moe_block<activation_func, multiply, EMBEDDING_DIM, NUM_HEADS,
111109
inference_mode_tag, SUBNET>;
112110

113111
// Complete network type selector based on training/inference mode
@@ -128,7 +126,6 @@ namespace dlib
128126
<< "- Layers: " << NUM_LAYERS << "\n"
129127
<< "- Attention heads: " << NUM_HEADS << "\n"
130128
<< "- Embedding dimension: " << EMBEDDING_DIM << "\n"
131-
<< "- Sequence length: " << MAX_SEQ_LEN << "\n"
132129
<< "- Architecture: Transformer with MoE feed-forward layers\n"
133130
<< "- Experts per layer: 4 (auto top-n selection)";
134131
return ss.str();
@@ -602,8 +599,7 @@ int main(int argc, char** argv)
602599
num_tokens, // vocab_size
603600
num_layers, // number of layers
604601
num_heads, // number of attention heads
605-
embedding_dim, // embedding dimension
606-
max_seq_len // maximum sequence length
602+
embedding_dim // embedding dimension
607603
> ;
608604

609605
// Load internal dataset
@@ -938,19 +934,20 @@ int main(int argc, char** argv)
938934
<< tokenized_segments.size() << ") for generation\n";
939935
const auto& selected_segment = tokenized_segments[segment_idx];
940936

937+
long prompt_seq_len = max_seq_len;
941938
if (selected_segment.size() < (size_t)max_seq_len) {
942-
cerr << "Error: Selected segment has only " << selected_segment.size()
939+
cerr << "Warning: Selected segment has only " << selected_segment.size()
943940
<< " tokens, need at least " << max_seq_len << ".\n";
944-
return 0;
941+
prompt_seq_len = (selected_segment.size() * 2) / 3;
945942
}
946943

947-
// Extract prompt tokens (first max_seq_len tokens of the segment)
944+
// Extract prompt tokens (first prompt_seq_len tokens of the segment)
948945
std::vector<int> prompt_tokens(selected_segment.begin(),
949-
selected_segment.begin() + max_seq_len);
946+
selected_segment.begin() + prompt_seq_len);
950947
cout << "Using " << prompt_tokens.size() << " tokens for initial prompt.\n";
951948

952949
// Setup inference context
953-
inference_context llm_context(max_seq_len, 4, tokenizer.get_special_token_id("<pad>"));
950+
inference_context llm_context(max_seq_len*2, 4, tokenizer.get_special_token_id("<pad>"));
954951
llm_context.add_tokens(prompt_tokens);
955952
auto input_seq = llm_context.get_input_window();
956953

@@ -969,7 +966,7 @@ int main(int argc, char** argv)
969966
cout << "Starting autoregressive generation...\n";
970967

971968
// Generation parameters
972-
const size_t tokens_to_generate = selected_segment.size() - max_seq_len;
969+
const size_t tokens_to_generate = selected_segment.size() - prompt_seq_len;
973970
std::vector<int> generated_tokens;
974971
generated_tokens.reserve(tokens_to_generate);
975972

@@ -1021,7 +1018,7 @@ int main(int argc, char** argv)
10211018
cout << "\n=== Validation: comparing generated vs. original segment ===\n";
10221019

10231020
// Extract reference tokens (the part we tried to regenerate)
1024-
std::vector<int> reference_tokens(selected_segment.begin() + max_seq_len,
1021+
std::vector<int> reference_tokens(selected_segment.begin() + prompt_seq_len,
10251022
selected_segment.end());
10261023

10271024
// Limit comparison to the length of generated tokens

0 commit comments

Comments
 (0)