
Commit 85eb2c9 ("Update"), parent 44830f7

6 files changed: +205, -157 lines

dlib/cuda/cpu_dlib.h (65 additions, 59 deletions)
```diff
@@ -772,88 +772,94 @@ namespace dlib
 
     class compute_loss_cross_entropy_per_logit
     {
-        /*! The point of this class is to compute the loss for loss_cross_entropy_per_logit_
-            on the cpu to provide an analogous implementation of the cuda version
+        /*!
+            Computes cross-entropy loss for causal language modeling
+            Uses all sequence positions (except last) for training
+            Each position t predicts the token at position t+1
        !*/
    public:
-        compute_loss_cross_entropy_per_logit()
-        {
-        }
-
+        compute_loss_cross_entropy_per_logit() {}
+
        template <typename const_label_iterator>
        void operator()(
            const_label_iterator truth,
+            const tensor& input_tensor,
            const tensor& output_tensor,
            tensor& grad,
            double& loss
        ) const
        {
-            DLIB_CASSERT(output_tensor.k() == 1,
-                "output_tensor.k() = " << output_tensor.k());
-
+            DLIB_CASSERT(output_tensor.k() == 1);
+            DLIB_CASSERT(input_tensor.k() == 1);
+            DLIB_CASSERT(input_tensor.nc() == 1);
+
            const long batch_size = output_tensor.num_samples();
            const long seq_len = output_tensor.nr();
            const long vocab_size = output_tensor.nc();
-
-            // The loss we output is the average loss over the mini-batch
-            const double scale = 1.0 / batch_size;
+
+            // Normalization over all positions
+            const double scale = 1.0 / (batch_size * seq_len);
+
            loss = 0.0;
-
            const float* out_data = output_tensor.host();
+            const float* in_data = input_tensor.host();
            float* g = grad.host();
-
-            // Zero out all gradients first. Gradients will only be non-zero at the
-            // last position (seq_len-1) of each sequence where the loss is computed
+
            std::fill(g, g + grad.size(), 0.0f);
-
-            // Compute loss and gradients only for the last position of each sequence.
-            // This implements the standard next token prediction objective used in
-            // autoregressive language models
+
            for (long i = 0; i < batch_size; ++i)
            {
-                const unsigned long target_class = *(truth + i);
-
-                // The network must produce a number of outputs that is equal to the number
-                // of labels when using this type of loss
-                DLIB_CASSERT(target_class < static_cast<unsigned long>(vocab_size),
-                    "target_class: " << target_class << ", vocab_size: " << vocab_size);
-
-                // Compute softmax for numerical stability using the log-sum-exp trick.
-                // First, find the maximum value for this position to prevent overflow
-                float max_val = out_data[tensor_index(output_tensor, i, 0, seq_len - 1, 0)];
-                for (long c = 1; c < vocab_size; ++c)
-                {
-                    const float val = out_data[tensor_index(output_tensor, i, 0, seq_len - 1, c)];
-                    max_val = std::max(max_val, val);
-                }
-
-                // Compute exp(x - max) and sum for the softmax denominator
-                float sum_exp = 0;
-                for (long c = 0; c < vocab_size; ++c)
-                {
-                    const unsigned long idx = tensor_index(output_tensor, i, 0, seq_len - 1, c);
-                    const float exp_val = std::exp(out_data[idx] - max_val);
-                    g[idx] = exp_val; // Temporarily store exp values
-                    sum_exp += exp_val;
-                }
-
-                // Normalize to get softmax probabilities, compute loss, and set gradients
-                for (long c = 0; c < vocab_size; ++c)
-                {
-                    const unsigned long idx = tensor_index(output_tensor, i, 0, seq_len - 1, c);
-                    const float softmax_val = g[idx] / sum_exp;
-
-                    if (static_cast<unsigned long>(c) == target_class)
-                    {
-                        // Cross-entropy loss: -log(p(target_class))
-                        loss += scale * (-std::log(std::max(softmax_val, 1e-10f)));
-                        // Gradient for the target class: scale * (p - 1)
-                        g[idx] = scale * (softmax_val - 1.0f);
-                    }
-                    else
-                    {
-                        // Gradient for non-target classes: scale * p
-                        g[idx] = scale * softmax_val;
-                    }
-                }
+                // Loop over all positions (0 to seq_len-1)
+                for (long t = 0; t < seq_len; ++t)
+                {
+                    unsigned long target_class;
+
+                    // Extract target token
+                    if (t < seq_len - 1) {
+                        // For positions 0 to seq_len-2: target from input_tensor[t+1]
+                        target_class = static_cast<unsigned long>(
+                            in_data[tensor_index(input_tensor, i, 0, t + 1, 0)]
+                        );
+                    } else {
+                        // For last position (seq_len-1): target from truth
+                        target_class = *(truth + i);
+                    }
+
+                    DLIB_CASSERT(target_class < static_cast<unsigned long>(vocab_size));
+
+                    // Find max logit for numerical stability
+                    float max_val = out_data[tensor_index(output_tensor, i, 0, t, 0)];
+                    for (long c = 1; c < vocab_size; ++c)
+                    {
+                        const float val = out_data[tensor_index(output_tensor, i, 0, t, c)];
+                        max_val = std::max(max_val, val);
+                    }
+
+                    // Compute softmax denominator
+                    float sum_exp = 0.0f;
+                    for (long c = 0; c < vocab_size; ++c)
+                    {
+                        const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
+                        const float exp_val = std::exp(out_data[idx] - max_val);
+                        g[idx] = exp_val;
+                        sum_exp += exp_val;
+                    }
+
+                    // Compute loss and gradients
+                    for (long c = 0; c < vocab_size; ++c)
+                    {
+                        const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
+                        const float softmax_val = g[idx] / sum_exp;
+
+                        if (static_cast<unsigned long>(c) == target_class)
+                        {
+                            loss += scale * (-std::log(std::max(softmax_val, 1e-10f)));
+                            g[idx] = scale * (softmax_val - 1.0f);
+                        }
+                        else
+                        {
+                            g[idx] = scale * softmax_val;
+                        }
+                    }
+                }
            }
        }
```

dlib/cuda/cuda_dlib.cu (46 additions, 35 deletions)
```diff
@@ -3114,6 +3114,7 @@ namespace dlib
        float* loss_out,
        float* g,
        const unsigned long* truth,
+        const float* input_data,
        const float* out_data,
        size_t batch_size,
        size_t seq_len,
```
```diff
@@ -3125,54 +3126,63 @@
 
        for (auto sample_idx : grid_stride_range(0, batch_size))
        {
-            const unsigned long target_class = truth[sample_idx];
-
-            const size_t last_pos = seq_len - 1;
-
-            float max_val = out_data[sample_idx * seq_len * vocab_size + last_pos * vocab_size + 0];
-            for (size_t c = 1; c < vocab_size; ++c)
-            {
-                const size_t idx = sample_idx * seq_len * vocab_size + last_pos * vocab_size + c;
-                max_val = ::max(max_val, out_data[idx]);
-            }
-
-            float sum_exp = 0.0f;
-            for (size_t c = 0; c < vocab_size; ++c)
-            {
-                const size_t idx = sample_idx * seq_len * vocab_size + last_pos * vocab_size + c;
-                const float exp_val = ::exp(out_data[idx] - max_val);
-                g[idx] = exp_val;
-                sum_exp += exp_val;
-            }
-
-            for (size_t c = 0; c < vocab_size; ++c)
-            {
-                const size_t idx = sample_idx * seq_len * vocab_size + last_pos * vocab_size + c;
-                const float softmax_val = g[idx] / sum_exp;
-
-                if (c == target_class)
-                {
-                    total_loss += -::log(::max(softmax_val, 1e-10f));
-                    g[idx] = scale * (softmax_val - 1.0f);
-                }
-                else
-                {
-                    g[idx] = scale * softmax_val;
-                }
-            }
+            for (size_t t = 0; t < seq_len; ++t)
+            {
+                unsigned long target_class;
+
+                if (t < seq_len - 1) {
+                    const size_t input_idx = sample_idx * seq_len + (t + 1);
+                    target_class = static_cast<unsigned long>(input_data[input_idx]);
+                }
+                else {
+                    target_class = truth[sample_idx];
+                }
+
+                const size_t base_idx = sample_idx * seq_len * vocab_size + t * vocab_size;
+                float max_val = out_data[base_idx + 0];
+                for (size_t c = 1; c < vocab_size; ++c)
+                {
+                    max_val = ::max(max_val, out_data[base_idx + c]);
+                }
+
+                float sum_exp = 0.0f;
+                for (size_t c = 0; c < vocab_size; ++c)
+                {
+                    const size_t idx = base_idx + c;
+                    const float exp_val = ::exp(out_data[idx] - max_val);
+                    g[idx] = exp_val;
+                    sum_exp += exp_val;
+                }
+
+                for (size_t c = 0; c < vocab_size; ++c)
+                {
+                    const size_t idx = base_idx + c;
+                    const float softmax_val = g[idx] / sum_exp;
+
+                    if (c == target_class)
+                    {
+                        total_loss += -::log(::max(softmax_val, 1e-10f));
+                        g[idx] = scale * (softmax_val - 1.0f);
+                    }
+                    else
+                    {
+                        g[idx] = scale * softmax_val;
+                    }
+                }
+            }
        }
 
        warp_reduce_atomic_add(*loss_out, total_loss);
    }
 
-    void compute_loss_cross_entropy_per_logit::
-    do_work(
-        cuda_data_ptr<float> loss_work_buffer,
-        cuda_data_ptr<const unsigned long> truth_buffer,
-        const tensor& subnetwork_output,
-        tensor& gradient,
-        double& loss
-    )
+    void compute_loss_cross_entropy_per_logit::do_work(
+        cuda_data_ptr<float> loss_work_buffer,
+        cuda_data_ptr<const unsigned long> truth_buffer,
+        const tensor& input_tensor,
+        const tensor& subnetwork_output,
+        tensor& gradient,
+        double& loss
+    )
    {
        CHECK_CUDA(cudaMemset(gradient.device(), 0, gradient.size() * sizeof(float)));
        CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
```
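Written as an equation, the objective the rewritten kernel (and the CPU path above) computes appears to be the standard causal-LM cross-entropy, averaged over both the batch and the sequence. The notation below is mine, not from the commit: $B$ is the batch size, $T$ the sequence length, $z_{i,t}$ the logit vector for sample $i$ at position $t$, $x_{i,t}$ the input token ids, and $\mathrm{truth}_i$ the per-sample label.

$$
\mathcal{L} \;=\; -\frac{1}{B\,T}\sum_{i=1}^{B}\sum_{t=0}^{T-1} \log \operatorname{softmax}(z_{i,t})_{\,y_{i,t}},
\qquad
y_{i,t} =
\begin{cases}
x_{i,\,t+1} & \text{if } t < T-1 \\
\mathrm{truth}_i & \text{if } t = T-1
\end{cases}
$$

The per-logit gradient is $\frac{\partial \mathcal{L}}{\partial z_{i,t,c}} = \frac{1}{B\,T}\big(\operatorname{softmax}(z_{i,t})_c - \mathbf{1}[c = y_{i,t}]\big)$, which is exactly the `scale * (softmax_val - 1.0f)` / `scale * softmax_val` pair in the code. Note that the kernel accumulates the unscaled $-\log$ terms into `total_loss` and applies the averaging only to the gradient, while the CPU path applies `scale` to both.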
```diff
@@ -3181,12 +3191,13 @@ namespace dlib
        const long seq_len = subnetwork_output.nr();
        const long vocab_size = subnetwork_output.nc();
 
-        const double scale = 1.0 / batch_size;
+        const double scale = 1.0 / (batch_size * seq_len);
 
        launch_kernel(_cuda_compute_loss_cross_entropy_per_logit, max_jobs(batch_size),
            loss_work_buffer.data(),
            gradient.device(),
            truth_buffer.data(),
+            input_tensor.device(),
            subnetwork_output.device(),
            batch_size,
            seq_len,
```
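A quick worked example of the normalization change (the numbers are illustrative, not from the commit): with batch_size = 32 and seq_len = 128, the old scale was 1/32 = 0.03125, while the new scale is 1/(32 * 128) = 1/4096, roughly 0.000244. Each position's gradient contribution is therefore 128 times smaller, but 128 positions per sample now contribute instead of one, so the overall gradient magnitude stays roughly comparable and the reported loss becomes an average per predicted token rather than per sample.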

dlib/cuda/cuda_dlib.h (38 additions, 42 deletions)
```diff
@@ -667,55 +667,51 @@ namespace dlib
 
    // ----------------------------------------------------------------------------------------
 
-    class compute_loss_cross_entropy_per_logit
-    {
-        /*!
-            The point of this class is to compute the loss computed by
-            loss_cross_entropy_per_logit_, but to do so with CUDA.
-        !*/
-    public:
-
-        compute_loss_cross_entropy_per_logit()
-        {
-        }
-
-        template <typename const_label_iterator>
-        void operator() (
-            const_label_iterator truth,
-            const tensor& subnetwork_output,
-            tensor& gradient,
-            double& loss
-        ) const
-        {
-            const size_t bytes_per_sample = sizeof(unsigned long);
-            buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_sample + sizeof(float));
-
-            cuda_data_ptr<float> loss_buf = static_pointer_cast<float>(buf, 1);
-            buf = buf+sizeof(float);
-
-            for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth)
-            {
-                const unsigned long t = *truth;
-                memcpy(buf + i*bytes_per_sample, &t, bytes_per_sample);
-            }
-
-            auto truth_buf = static_pointer_cast<const unsigned long>(buf, subnetwork_output.num_samples());
-
-            do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss);
-        }
-
-    private:
-
-        static void do_work(
-            cuda_data_ptr<float> loss_work_buffer,
-            cuda_data_ptr<const unsigned long> truth_buffer,
-            const tensor& subnetwork_output,
-            tensor& gradient,
-            double& loss
-        );
-
-        mutable cuda_data_void_ptr buf;
-    };
+    class compute_loss_cross_entropy_per_logit
+    {
+        /*!
+            The point of this class is to compute the loss computed by
+            loss_cross_entropy_per_logit_, but to do so with CUDA
+        !*/
+    public:
+        compute_loss_cross_entropy_per_logit() {}
+
+        template <typename const_label_iterator>
+        void operator() (
+            const_label_iterator truth,
+            const tensor& input_tensor,      // Source tokens
+            const tensor& subnetwork_output, // Logits
+            tensor& gradient,
+            double& loss
+        ) const
+        {
+            const size_t bytes_per_sample = sizeof(unsigned long);
+            buf = device_global_buffer(subnetwork_output.num_samples() * bytes_per_sample + sizeof(float));
+            cuda_data_ptr<float> loss_buf = static_pointer_cast<float>(buf, 1);
+            buf = buf + sizeof(float);
+
+            for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth)
+            {
+                const unsigned long t = *truth;
+                memcpy(buf + i * bytes_per_sample, &t, bytes_per_sample);
+            }
+
+            auto truth_buf = static_pointer_cast<const unsigned long>(buf, subnetwork_output.num_samples());
+            do_work(loss_buf, truth_buf, input_tensor, subnetwork_output, gradient, loss);
+        }
+
+    private:
+        static void do_work(
+            cuda_data_ptr<float> loss_work_buffer,
+            cuda_data_ptr<const unsigned long> truth_buffer,
+            const tensor& input_tensor,
+            const tensor& subnetwork_output,
+            tensor& gradient,
+            double& loss
+        );
+
+        mutable cuda_data_void_ptr buf;
+    };
 
    // ----------------------------------------------------------------------------------------
 
```