@@ -2088,6 +2088,126 @@ namespace dlib
             row_stride, col_stride, add_to);
         }

+    // ----------------------------------------------------------------------------------------
+
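+        // Forward pass for the embedding lookup.  Each output element d[i] is
+        // mapped back to its (sample n, channel k, row r, column c) coordinates,
+        // the token id is read from the input tensor s, and column c of the
+        // corresponding row of the embedding table e is copied through.  Token
+        // ids outside the table (>= es) produce zeros.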
+        __global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc,
+            float* d, const float* s, const float* e, size_t es
+        )
+        {
+            for (auto i : grid_stride_range(0, dsize))
+            {
+                const auto n = i / (dk * dr * dc);
+                const auto s_idx = i % (dk * dr * dc);
+                const auto k = (s_idx / (dr * dc)) % dk;
+                const auto r = (s_idx / dc) % dr;
+                const auto c = s_idx % dc;
+
+                const unsigned long t_idx = static_cast<unsigned long>(s[(n * dk + k) * dr + r]);
+
+                if (t_idx < es)
+                    d[i] = e[t_idx * dc + c];
+                else
+                    d[i] = 0.0f;
+            }
+        }
+
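+        // Host-side wrapper: checks that src holds at least one token row per
+        // sample and that embs is laid out as a num_samples x k table of
+        // embedding rows, then runs one grid-stride pass over every element of
+        // dest.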
+        void embeddings(
+            resizable_tensor& dest,
+            const tensor& src,
+            const tensor& embs
+        )
+        {
+            DLIB_CASSERT(
+                src.nr() > 0 &&
+                embs.num_samples() > 0 &&
+                embs.k() > 0 &&
+                embs.nr() == 1 &&
+                embs.nc() == 1,
+                "\nsrc.num_samples(): " << src.num_samples() <<
+                "\nsrc.k(): " << src.k() <<
+                "\nsrc.nr(): " << src.nr() <<
+                "\nsrc.nc(): " << src.nc() <<
+                "\nembs.num_samples(): " << embs.num_samples() <<
+                "\nembs.k(): " << embs.k() <<
+                "\nembs.nr(): " << embs.nr() <<
+                "\nembs.nc(): " << embs.nc()
+            );
+
+            const long dk = dest.k();
+            const long dr = dest.nr();
+            const long dc = dest.nc();
+
+            launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc,
+                dest.device(), src.device(), embs.device(), embs.num_samples());
+        }
+
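+        // Backward pass for the embedding lookup.  Each gradient element is
+        // scattered back into the embedding-table row selected by its token id.
+        // Tokens that occur more than once (f[t_idx] > 1) can be hit by several
+        // threads at once, so those rows are updated with atomicAdd; rows of
+        // unique tokens take a plain write.  When sl is set, the step is further
+        // scaled by the inverse token frequency so updates for common tokens are
+        // damped.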
+        __global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc,
+            const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es
+        )
+        {
+            for (auto i : grid_stride_range(0, ssize))
+            {
+                const auto n = i / (sk * sr * sc);
+                const auto s_idx = i % (sk * sr * sc);
+                const auto k = (s_idx / (sr * sc)) % sk;
+                const auto r = (s_idx / sc) % sr;
+                const auto c = s_idx % sc;
+
+                const unsigned long t_idx = static_cast<unsigned long>(o[(n * sk + k) * sr + r]);
+                if (t_idx < es)
+                {
+                    const float f_t = f[t_idx];
+                    float f_s = 1.0f;
+
+                    // Clamp the inverse frequency to [0.15, 1] so frequent tokens
+                    // get smaller updates that never vanish entirely.
+                    if (sl && f_t != 0.0f) f_s = fmaxf(0.15f, fminf(1.0f / f_t, 1.0f));
+                    if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s);
+                    else g[t_idx * sc + c] -= gi[i] * lr * f_s;
+                }
+            }
+        }
+
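+        // Host-side wrapper for the backward pass: gradient_input must match the
+        // shape of prev, with its column dimension equal to the embedding width
+        // grads.k().  freqs holds one occurrence count per vocabulary entry.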
+        void embeddings_gradient(
+            const tensor& prev,
+            const tensor& gradient_input,
+            tensor& grads,
+            const tensor& freqs,
+            float learning_rate,
+            bool scale
+        )
+        {
+            DLIB_CASSERT(
+                prev.nr() > 0 &&
+                gradient_input.num_samples() == prev.num_samples() &&
+                gradient_input.k() == prev.k() &&
+                gradient_input.nr() == prev.nr() &&
+                gradient_input.nc() == grads.k() &&
+                grads.num_samples() > 0 &&
+                grads.k() > 0 &&
+                grads.nr() == 1 &&
+                grads.nc() == 1,
+                "\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
+                "\ngradient_input.k(): " << gradient_input.k() <<
+                "\ngradient_input.nr(): " << gradient_input.nr() <<
+                "\ngradient_input.nc(): " << gradient_input.nc() <<
+                "\nprev.num_samples(): " << prev.num_samples() <<
+                "\nprev.k(): " << prev.k() <<
+                "\nprev.nr(): " << prev.nr() <<
+                "\nprev.nc(): " << prev.nc() <<
+                "\ngrads.num_samples(): " << grads.num_samples() <<
+                "\ngrads.k(): " << grads.k() <<
+                "\ngrads.nr(): " << grads.nr() <<
+                "\ngrads.nc(): " << grads.nc()
+            );
+
+            const long sk = gradient_input.k();
+            const long sr = gradient_input.nr();
+            const long sc = gradient_input.nc();
+
+            launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc,
+                prev.device(), gradient_input.device(), grads.device(), freqs.device(),
+                learning_rate, scale, grads.num_samples());
+        }
+
     // ----------------------------------------------------------------------------------------

         __global__ void _cuda_layer_normalize(