
Commit 8567075

Cydral, arrufat, and davisking authored
Add embeddings_ layer and supporting utility functions (#3021)
* Fix stride indexing bugs in the `reorg` and `reorg_gradient` functions (CPU & CUDA) and add an `add_to` parameter
* Add the missing 'add_to' parameter to the CUDA reorg_gradient.launch_kernel() call
* Cleanup: remove using namespace std; (#3016)
  * remove using namespace std from headers
  * more std::
  * more std:: on Windows stuff
  * remove uses of using namespace std::chrono
  * do not use C++17 features
  * add Davis' suggestion
  * revert some more stuff
  * revert removing include
  * more std::chrono stuff
  * fix build error
* Adjust comment formatting to match other dlib comments
* Add positional encodings layer to dlib
* Implement embeddings_ layer and add supporting utility functions to tensor_tools.h
* Updates
* Update dlib/cuda/tensor_tools.h

---------

Co-authored-by: Adrià <[email protected]>
Co-authored-by: Davis King <[email protected]>
Co-authored-by: Davis E. King <[email protected]>
1 parent 488ee5c commit 8567075

11 files changed: +982 additions, -10 deletions

.vscode/settings.json

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
{
    "githubPullRequests.ignoredPullRequestBranches": [
        "master"
    ]
}

dlib/cuda/cpu_dlib.cpp

Lines changed: 115 additions & 0 deletions

@@ -2421,6 +2421,121 @@ namespace dlib
        }

    // ------------------------------------------------------------------------------------

        void embeddings(
            resizable_tensor& dest,
            const tensor& src,
            const tensor& embs
        )
        {
            DLIB_CASSERT(
                src.nr() > 0 &&
                embs.num_samples() > 0 &&
                embs.k() > 0 &&
                embs.nr() == 1 &&
                embs.nc() == 1,
                "\nsrc.num_samples(): " << src.num_samples() <<
                "\nsrc.k(): " << src.k() <<
                "\nsrc.nr(): " << src.nr() <<
                "\nsrc.nc(): " << src.nc() <<
                "\nembs.num_samples(): " << embs.num_samples() <<
                "\nembs.k(): " << embs.k() <<
                "\nembs.nr(): " << embs.nr() <<
                "\nembs.nc(): " << embs.nc()
            );

            long ns = dest.num_samples(), nk = dest.k(), nr = dest.nr(), nc = dest.nc();
            const float* src_data = src.host();
            float* dest_data = dest.host();
            const float* embs_data = embs.host();
            for (long s = 0; s < ns; ++s)
            {
                for (long k = 0; k < nk; ++k)
                {
                    for (long r = 0; r < nr; ++r)
                    {
                        const unsigned long token_idx = static_cast<unsigned long>(src_data[tensor_index(src, s, k, r, 0)]);
                        if (token_idx < embs.num_samples())
                        {
                            for (long c = 0; c < nc; ++c)
                                dest_data[tensor_index(dest, s, k, r, c)] = embs_data[tensor_index(embs, token_idx, c, 0, 0)];
                        }
                        else
                        {
                            for (long c = 0; c < nc; ++c)
                                dest_data[tensor_index(dest, s, k, r, c)] = 0;
                        }
                    }
                }
            }
        }

        void embeddings_gradient(
            const tensor& prev,
            const tensor& gradient_input,
            tensor& grads,
            const tensor& freqs,
            float learning_rate,
            bool scale
        )
        {
            DLIB_CASSERT(
                prev.nr() > 0 &&
                gradient_input.num_samples() == prev.num_samples() &&
                gradient_input.k() == prev.k() &&
                gradient_input.nr() == prev.nr() &&
                gradient_input.nc() == grads.k() &&
                grads.num_samples() > 0 &&
                grads.k() > 0 &&
                grads.nr() == 1 &&
                grads.nc() == 1,
                "\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
                "\ngradient_input.k(): " << gradient_input.k() <<
                "\ngradient_input.nr(): " << gradient_input.nr() <<
                "\ngradient_input.nc(): " << gradient_input.nc() <<
                "\nprev.num_samples(): " << prev.num_samples() <<
                "\nprev.k(): " << prev.k() <<
                "\nprev.nr(): " << prev.nr() <<
                "\nprev.nc(): " << prev.nc() <<
                "\ngrads.num_samples(): " << grads.num_samples() <<
                "\ngrads.k(): " << grads.k() <<
                "\ngrads.nr(): " << grads.nr() <<
                "\ngrads.nc(): " << grads.nc()
            );

            const float* prev_data = prev.host();
            const float* gradient_input_data = gradient_input.host();
            const float* freqs_data = freqs.host();
            float* grads_data = grads.host();
            long ns = gradient_input.num_samples(), nk = gradient_input.k();
            long nr = gradient_input.nr(), nc = gradient_input.nc();

            std::vector<dlib::mutex> embedding_mutexes(grads.num_samples());
            parallel_for(0, ns * nk, [&](long i)
            {
                long s = i / nk;
                long k = i % nk;

                for (long r = 0; r < nr; ++r)
                {
                    const unsigned long token_idx = static_cast<unsigned long>(prev_data[tensor_index(prev, s, k, r, 0)]);
                    if (token_idx < grads.num_samples())
                    {
                        const float freg_token = freqs_data[token_idx];
                        float freq_scale = 1.0f;

                        if (scale && freg_token != 0.0f) freq_scale = std::min(0.15f, std::max(1.0f / freg_token, 1.0f));
                        auto_mutex locker(embedding_mutexes[token_idx]);
                        for (long c = 0; c < nc; ++c)
                        {
                            const float gradient = gradient_input_data[tensor_index(gradient_input, s, k, r, c)];
                            grads_data[tensor_index(grads, token_idx, c, 0, 0)] -= (gradient * learning_rate * freq_scale);
                        }
                    }
                }
            });
        }

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
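For readers skimming the diff: the CPU forward pass above is a plain embedding-table gather. Each cell of src holds a token index stored as a float, the matching row of embs is copied into dest, and out-of-range indices yield zero vectors. A minimal standalone sketch of the same semantics, using std::vector in place of dlib tensors (all names below are illustrative, not part of this commit):

#include <cstddef>
#include <vector>

// Gather rows of a row-major (vocab_size x dims) table into an output
// buffer, one row per token; out-of-range tokens stay all-zero, matching
// the bounds check in cpu::embeddings above.
std::vector<float> gather_embeddings(
    const std::vector<float>& tokens,  // token ids stored as floats
    const std::vector<float>& table,   // vocab_size * dims, row-major
    std::size_t vocab_size,
    std::size_t dims
)
{
    std::vector<float> out(tokens.size() * dims, 0.0f);
    for (std::size_t i = 0; i < tokens.size(); ++i)
    {
        const auto idx = static_cast<std::size_t>(tokens[i]);
        if (idx < vocab_size)
            for (std::size_t d = 0; d < dims; ++d)
                out[i * dims + d] = table[idx * dims + d];
    }
    return out;
}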

dlib/cuda/cpu_dlib.h

Lines changed: 17 additions & 0 deletions

@@ -517,6 +517,23 @@ namespace dlib
            const tensor& gradient_input
        );

    // -----------------------------------------------------------------------------------

        void embeddings(
            resizable_tensor& dest,
            const tensor& src,
            const tensor& embs
        );

        void embeddings_gradient(
            const tensor& prev,
            const tensor& gradient_input,
            tensor& grads,
            const tensor& freqs,
            float learning_rate,
            bool scale
        );

    // -----------------------------------------------------------------------------------

        class pooling

dlib/cuda/cuda_dlib.cu

Lines changed: 120 additions & 0 deletions

@@ -2088,6 +2088,126 @@ namespace dlib
                row_stride, col_stride, add_to);
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc,
            float* d, const float* s, const float* e, size_t es
        )
        {
            for (auto i : grid_stride_range(0, dsize))
            {
                const auto n = i / (dk * dr * dc);
                const auto s_idx = i % (dk * dr * dc);
                const auto k = (s_idx / (dr * dc)) % dk;
                const auto r = (s_idx / dc) % dr;
                const auto c = s_idx % dc;

                const unsigned long t_idx = static_cast<unsigned long>(s[(n * dk + k) * dr + r]);

                if (t_idx < es)
                    d[i] = e[t_idx * dc + c];
                else
                    d[i] = 0.0f;
            }
        }

        void embeddings(
            resizable_tensor& dest,
            const tensor& src,
            const tensor& embs
        )
        {
            DLIB_CASSERT(
                src.nr() > 0 &&
                embs.num_samples() > 0 &&
                embs.k() > 0 &&
                embs.nr() == 1 &&
                embs.nc() == 1,
                "\nsrc.num_samples(): " << src.num_samples() <<
                "\nsrc.k(): " << src.k() <<
                "\nsrc.nr(): " << src.nr() <<
                "\nsrc.nc(): " << src.nc() <<
                "\nembs.num_samples(): " << embs.num_samples() <<
                "\nembs.k(): " << embs.k() <<
                "\nembs.nr(): " << embs.nr() <<
                "\nembs.nc(): " << embs.nc()
            );

            const long dk = dest.k();
            const long dr = dest.nr();
            const long dc = dest.nc();

            launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc,
                dest.device(), src.device(), embs.device(), embs.num_samples());
        }

        __global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc,
            const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es
        )
        {
            for (auto i : grid_stride_range(0, ssize))
            {
                const auto n = i / (sk * sr * sc);
                const auto s_idx = i % (sk * sr * sc);
                const auto k = (s_idx / (sr * sc)) % sk;
                const auto r = (s_idx / sc) % sr;
                const auto c = s_idx % sc;

                const unsigned long t_idx = static_cast<unsigned long>(o[(n * sk + k) * sr + r]);
                if (t_idx < es)
                {
                    const float f_t = f[t_idx];
                    float f_s = 1.0f;

                    if (sl && f_t != 0.0f) f_s = fminf(0.15f, fmaxf(1.0f / f_t, 1.0f));
                    if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s);
                    else g[t_idx * sc + c] -= gi[i] * lr * f_s;
                }
            }
        }

        void embeddings_gradient(
            const tensor& prev,
            const tensor& gradient_input,
            tensor& grads,
            const tensor& freqs,
            float learning_rate,
            bool scale
        )
        {
            DLIB_CASSERT(
                prev.nr() > 0 &&
                gradient_input.num_samples() == prev.num_samples() &&
                gradient_input.k() == prev.k() &&
                gradient_input.nr() == prev.nr() &&
                gradient_input.nc() == grads.k() &&
                grads.num_samples() > 0 &&
                grads.k() > 0 &&
                grads.nr() == 1 &&
                grads.nc() == 1,
                "\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
                "\ngradient_input.k(): " << gradient_input.k() <<
                "\ngradient_input.nr(): " << gradient_input.nr() <<
                "\ngradient_input.nc(): " << gradient_input.nc() <<
                "\nprev.num_samples(): " << prev.num_samples() <<
                "\nprev.k(): " << prev.k() <<
                "\nprev.nr(): " << prev.nr() <<
                "\nprev.nc(): " << prev.nc() <<
                "\ngrads.num_samples(): " << grads.num_samples() <<
                "\ngrads.k(): " << grads.k() <<
                "\ngrads.nr(): " << grads.nr() <<
                "\ngrads.nc(): " << grads.nc()
            );

            const long sk = gradient_input.k();
            const long sr = gradient_input.nr();
            const long sc = gradient_input.nc();

            launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc,
                prev.device(), gradient_input.device(), grads.device(), freqs.device(),
                learning_rate, scale, grads.num_samples());
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_layer_normalize(
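Both kernels above walk a single grid-stride loop over the flattened output and recover the (n, k, r, c) coordinates from the flat index i. A quick host-side check of that decomposition, assuming the row-major (n, k, r, c) layout dlib tensors use (a standalone sketch, not part of the commit):

#include <cassert>
#include <cstddef>

int main()
{
    // Mirror the index arithmetic in _cuda_embeddings: i enumerates a
    // (ns, dk, dr, dc) tensor in row-major order.
    const std::size_t ns = 2, dk = 2, dr = 3, dc = 4;
    std::size_t i = 0;
    for (std::size_t n = 0; n < ns; ++n)
    for (std::size_t k = 0; k < dk; ++k)
    for (std::size_t r = 0; r < dr; ++r)
    for (std::size_t c = 0; c < dc; ++c, ++i)
    {
        const std::size_t s_idx = i % (dk * dr * dc);
        assert(i / (dk * dr * dc) == n);
        assert((s_idx / (dr * dc)) % dk == k);
        assert((s_idx / dc) % dr == r);
        assert(s_idx % dc == c);
    }
    return 0;
}

Note also that _cuda_embeddings_gradient only pays for atomicAdd when f_t > 1, i.e. when the token's frequency suggests concurrent updates to the same embedding row are likely; rare tokens take the cheaper non-atomic path.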

dlib/cuda/cuda_dlib.h

Lines changed: 17 additions & 0 deletions

@@ -561,6 +561,23 @@ namespace dlib
            const tensor& gradient_input
        );

    // -----------------------------------------------------------------------------------

        void embeddings(
            resizable_tensor& dest,
            const tensor& src,
            const tensor& embs
        );

        void embeddings_gradient(
            const tensor& prev,
            const tensor& gradient_input,
            tensor& grads,
            const tensor& freqs,
            float learning_rate,
            bool scale
        );

    // ----------------------------------------------------------------------------------------

        void copy_tensor(

dlib/cuda/tensor_tools.cpp

Lines changed: 31 additions & 0 deletions

@@ -1296,6 +1296,37 @@ namespace dlib { namespace tt
#endif
    }

// ----------------------------------------------------------------------------------------

    void embeddings(
        resizable_tensor& dest,
        const tensor& src,
        const tensor& embs
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::embeddings(dest, src, embs);
#else
        cpu::embeddings(dest, src, embs);
#endif
    }

    void embeddings_gradient(
        const tensor& prev,
        const tensor& gradient_input,
        tensor& grads,
        const tensor& freqs,
        float learning_rate,
        bool scale
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#else
        cpu::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#endif
    }

// ----------------------------------------------------------------------------------------

}}
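These tt:: wrappers are the entry points the new embeddings_ layer calls, with the CUDA/CPU choice made at compile time via DLIB_USE_CUDA. A hedged usage sketch of the forward call (shapes are illustrative; following the CPU implementation above, it assumes the caller sizes dest to src's geometry with dest.nc() == embs.k()):

#include <dlib/dnn.h>

int main()
{
    // Embedding table: vocabulary of 10 tokens, 4-dimensional embeddings
    // (num_samples() is the vocab size, k() the embedding dimension).
    dlib::resizable_tensor embs(10, 4, 1, 1);
    embs = 0.01f;

    // Batch of 2 samples, 1 channel, sequence length 3; each cell holds
    // a token index stored as a float.
    dlib::resizable_tensor src(2, 1, 3, 1);
    src = 5.0f;

    // One embs.k()-sized vector per input token; token 5 is in range,
    // so every output row is a copy of the table's sixth row.
    dlib::resizable_tensor dest(2, 1, 3, 4);
    dlib::tt::embeddings(dest, src, embs);
    return 0;
}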
